#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-self-use,too-many-arguments,unused-argument,not-callable,no-member,attribute-defined-outside-init
"""Bert Custom Handler."""
import json
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients
from captum.attr import visualization
from transformers import BertTokenizer
from ts.torch_handler.base_handler import BaseHandler

from bert_train import BertNewsClassifier
from wrapper import AGNewsmodelWrapper

logger = logging.getLogger(__name__)

class NewsClassifierHandler(BaseHandler):
    """TorchServe handler for BERT-based news classification.

    Takes a review / sentence and returns the label as one of
    world / sports / business / sci-tech.  Captum Integrated Gradients
    explanations are produced by :meth:`explain_handle`.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.class_mapping_file = None
        self.VOCAB_FILE = None

    def initialize(self, ctx):
        """Load the eager-mode state_dict based model from the model archive.

        :param ctx: TorchServe context holding ``system_properties``
        :raises RuntimeError: if the model definition file or the vocab
            file is missing from the model directory
        """
        properties = ctx.system_properties
        # FIX: always build a torch.device — the original mixed a
        # torch.device on CUDA with the plain string "cpu" otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = properties.get("model_dir")

        # Serialized weights file shipped inside the model archive.
        model_pt_path = os.path.join(model_dir, "bert.pth")
        # Model definition file (must accompany the weights).
        model_def_path = os.path.join(model_dir, "bert_train.py")
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model definition file")

        self.VOCAB_FILE = os.path.join(model_dir, "bert-base-uncased-vocab.txt")
        if not os.path.isfile(self.VOCAB_FILE):
            raise RuntimeError("Missing the vocab file")

        self.class_mapping_file = os.path.join(model_dir, "index_to_name.json")

        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model = BertNewsClassifier()
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True

    def preprocess(self, data):
        """Convert the incoming JSON request into an input-id tensor.

        :param data: list of request dicts; the text is read from the
            "data" key, falling back to "body"
        :return: tensor of token ids (1, seq_len) on ``self.device``
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")

        # Kept on self so explain_handle can re-tokenize the same request.
        self.text = text
        self.tokenizer = BertTokenizer(self.VOCAB_FILE)
        self.input_ids = torch.tensor(
            [self.tokenizer.encode(self.text, add_special_tokens=True)]
        ).to(self.device)
        return self.input_ids

    def inference(self, input_ids):
        """Predict the class for a review / sentence: whether it belongs
        to world / sports / business / sci-tech.

        :param input_ids: encoded input produced by :meth:`preprocess`
        :return: single-element list with the predicted class index
        """
        # FIX: use the passed-in encoding instead of silently reading
        # self.input_ids, so the method honours its own argument.
        inputs = input_ids.to(self.device)
        self.outputs = self.model.forward(inputs)
        self.out = np.argmax(self.outputs.cpu().detach())
        return [self.out.item()]

    def postprocess(self, inference_output):
        """Map the predicted index to a human-readable label.

        :param inference_output: output of :meth:`inference`
        :return: single-element list with the mapped label (JSON string)
            when the mapping file exists, otherwise the raw output
        """
        if os.path.exists(self.class_mapping_file):
            with open(self.class_mapping_file) as json_file:
                data = json.load(json_file)
            return [json.dumps(data[str(inference_output[0])])]

        return inference_output

    def add_attributions_to_visualizer(
        self,
        attributions,
        tokens,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        delta,
        vis_data_records,
    ):
        """Normalize attributions and append a Captum visualization record.

        :param attributions: raw Integrated Gradients attributions
        :param tokens: tokens aligned with the attribution entries
        :param pred_prob: predicted score (output of :meth:`score_func`)
        :param pred_class: predicted class
        :param true_class: ground-truth class
        :param attr_class: class the attributions were computed for
        :param delta: convergence delta from Integrated Gradients
        :param vis_data_records: list the new record is appended to (mutated)
        """
        # Collapse the embedding dimension, then L2-normalize per token.
        attributions = attributions.sum(dim=2).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        attributions = attributions.cpu().detach().numpy()

        # storing a couple of samples in an array for visualization purposes
        vis_data_records.append(
            visualization.VisualizationDataRecord(
                attributions,
                pred_prob,
                pred_class,
                true_class,
                attr_class,
                attributions.sum(),
                tokens,
                delta,
            )
        )

    def score_func(self, o):
        """Return the argmax class index of softmax-normalized logits."""
        output = F.softmax(o, dim=1)
        return np.argmax(output.cpu().detach())

    def summarize_attributions(self, attributions):
        """Summarise the attributions across the embedding dimension.

        :param attributions: attributions from Integrated Gradients
        :return: L2-normalized per-token attributions
        """
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions

    def explain_handle(self, model_wraper, text, target=1):
        """Captum explanations handler.

        Re-tokenizes the text captured during :meth:`preprocess` and runs
        Integrated Gradients over the BERT embeddings.

        :param model_wraper: unused; a fresh AGNewsmodelWrapper is built
            around ``self.model`` instead
        :param text: unused; ``self.text`` from preprocess is explained
        :param target: class index to attribute against
        :return: single-element list with a dict of words, importances
            and convergence delta
        """
        vis_data_records_base = []
        model_wrapper = AGNewsmodelWrapper(self.model)
        tokenizer = BertTokenizer(self.VOCAB_FILE)
        model_wrapper.eval()
        model_wrapper.zero_grad()
        encoding = tokenizer.encode_plus(
            self.text, return_attention_mask=True, return_tensors="pt", add_special_tokens=False
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        # Attribute over the embedding output rather than raw token ids.
        input_embedding_test = model_wrapper.model.bert_model.embeddings(input_ids)
        preds = model_wrapper(input_embedding_test, attention_mask)
        out = np.argmax(preds.cpu().detach(), axis=1).item()

        ig_1 = IntegratedGradients(model_wrapper)
        # FIX: honour the ``target`` argument (was hard-coded to 1).
        attributions, delta = ig_1.attribute(  # pylint: disable=no-member
            input_embedding_test,
            n_steps=500,
            return_convergence_delta=True,
            target=target,
        )
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy().tolist())

        feature_imp_dict = {
            "words": tokens,
            "importances": self.summarize_attributions(attributions).tolist(),
            "delta": delta[0].tolist(),
        }
        self.add_attributions_to_visualizer(
            attributions, tokens, self.score_func(preds), out, 2, 1, delta, vis_data_records_base
        )
        return [feature_imp_dict]