pipelines/samples/contrib/pytorch-samples/bert/bert_handler.py

208 lines
7.7 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-self-use,too-many-arguments,unused-argument,not-callable,no-member,attribute-defined-outside-init
""" Bert Custom Handler."""
from captum.attr import IntegratedGradients
import json
import logging
import os
import numpy as np
import torch
from transformers import BertTokenizer
from ts.torch_handler.base_handler import BaseHandler
from captum.attr import visualization
import torch.nn.functional as F
from bert_train import BertNewsClassifier
from wrapper import AGNewsmodelWrapper
logger = logging.getLogger(__name__)
class NewsClassifierHandler(BaseHandler):
    """TorchServe handler for the BERT news classifier.

    Takes a review / sentence, tokenizes it with the bundled BERT vocab,
    predicts a class index and maps it to a label through
    ``index_to_name.json`` (world / sports / business / sci-tech according
    to the original author). It also exposes Captum Integrated Gradients
    explanations through :meth:`explain_handle`.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.class_mapping_file = None
        self.VOCAB_FILE = None
        # Built once from VOCAB_FILE and reused across requests.
        self.tokenizer = None
        # Per-request state shared between preprocess/inference/explain_handle.
        self.text = None
        self.input_ids = None
        self.outputs = None
        self.out = None

    def _get_tokenizer(self):
        """Return the cached tokenizer, building it from ``VOCAB_FILE`` on first use."""
        if getattr(self, "tokenizer", None) is None:
            self.tokenizer = BertTokenizer(self.VOCAB_FILE)
        return self.tokenizer

    def initialize(self, ctx):
        """
        Load the eager-mode ``state_dict`` model and its companion files.

        Expects ``bert.pth``, ``bert_train.py`` and
        ``bert-base-uncased-vocab.txt`` inside the model directory;
        ``index_to_name.json`` is optional and only used by ``postprocess``.

        :param ctx: Context whose ``system_properties`` provide ``model_dir``
        :raises RuntimeError: if the model definition or vocab file is missing
        """
        properties = ctx.system_properties
        # Always store a torch.device (the original stored the str "cpu" on CPU hosts).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = properties.get("model_dir")

        # Serialized weights (state_dict).
        model_pt_path = os.path.join(model_dir, "bert.pth")

        # The model definition is imported at module level; this only verifies
        # that it was packaged with the archive.
        model_def_path = os.path.join(model_dir, "bert_train.py")
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model definition file")

        self.VOCAB_FILE = os.path.join(model_dir, "bert-base-uncased-vocab.txt")
        if not os.path.isfile(self.VOCAB_FILE):
            raise RuntimeError("Missing the vocab file")

        self.class_mapping_file = os.path.join(model_dir, "index_to_name.json")

        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model = BertNewsClassifier()
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Parse the vocab once here instead of on every request.
        self.tokenizer = BertTokenizer(self.VOCAB_FILE)

        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True

    def preprocess(self, data):
        """
        Receives text in form of json and converts it into an encoding for the inference stage

        Only the first request of the batch (``data[0]``) is read; its text is
        taken from the ``data`` key, falling back to ``body``.

        :param data: List of request dicts
        :return: output - tensor of token ids (with special tokens) on ``self.device``
        """
        request = data[0]
        text = request.get("data")
        if text is None:
            text = request.get("body")
        # Raw request bodies may arrive as bytes; the tokenizer needs str.
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        self.text = text

        tokenizer = self._get_tokenizer()
        self.input_ids = torch.tensor(
            [tokenizer.encode(self.text, add_special_tokens=True)]
        ).to(self.device)
        return self.input_ids

    def inference(self, input_ids):
        """
        Predict the class index for the encoded review / sentence.

        :param input_ids: Token-id tensor returned by ``preprocess``; when
            ``None`` the tensor stored by the last ``preprocess`` call is used
        :return: output - single-element list holding the predicted class index
        """
        if input_ids is None:
            input_ids = self.input_ids
        inputs = input_ids.to(self.device)
        # No gradients are needed for plain prediction.
        with torch.no_grad():
            self.outputs = self.model(inputs)
        self.out = torch.argmax(self.outputs.cpu().detach())
        return [self.out.item()]

    def postprocess(self, inference_output):
        """
        Does postprocess after inference to be returned to user

        :param inference_output: Output of inference (list with the class index)
        :return: output - list with the JSON-encoded label when the class
            mapping file exists, otherwise ``inference_output`` unchanged
        """
        mapping_file = self.class_mapping_file
        # Guard against None so an uninitialised handler degrades gracefully.
        if mapping_file and os.path.exists(mapping_file):
            with open(mapping_file, encoding="utf-8") as json_file:
                data = json.load(json_file)
            inference_output = json.dumps(data[str(inference_output[0])])
            return [inference_output]
        return inference_output

    def add_attributions_to_visualizer(
        self,
        attributions,
        tokens,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        delta,
        vis_data_records,
    ):
        """
        Append a Captum ``VisualizationDataRecord`` to ``vis_data_records``.

        The attributions are summed over dim 2, squeezed on dim 0, divided by
        their norm and converted to numpy before being stored.

        :param attributions: Attribution tensor from Integrated Gradients
        :param tokens: Tokens the attributions refer to
        :param pred_prob: Value stored as the record's predicted probability
        :param pred_class: Predicted class
        :param true_class: True class
        :param attr_class: Class the attributions were computed for
        :param delta: Convergence delta reported by Integrated Gradients
        :param vis_data_records: List that receives the new record (mutated)
        """
        attributions = attributions.sum(dim=2).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        attributions = attributions.cpu().detach().numpy()

        # storing couple samples in an array for visualization purposes
        vis_data_records.append(
            visualization.VisualizationDataRecord(
                attributions,
                pred_prob,
                pred_class,
                true_class,
                attr_class,
                attributions.sum(),
                tokens,
                delta,
            )
        )

    def score_func(self, o):
        """
        Apply softmax over dim 1 and return the argmax of the flattened result.

        NOTE(review): this yields an index, not a probability, although the
        caller passes it on as ``pred_prob`` - confirm the intent.

        :param o: Model output tensor
        :return: Index of the largest softmax value
        """
        output = F.softmax(o, dim=1)
        pre_pro = np.argmax(output.cpu().detach())
        return pre_pro

    def summarize_attributions(self, attributions):
        """
        Collapse attributions over the last dimension and normalize them.

        :param attributions: Attribution tensor from Integrated Gradients
        :return: Attributions summed over the last dim, squeezed on dim 0 and
            divided by their norm
        """
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions

    def explain_handle(self, model_wraper, text, target=1):
        """
        Captum explanations handler.

        Explains the text stored by the last ``preprocess`` call
        (``self.text``); ``model_wraper`` and ``text`` are accepted for
        interface compatibility but not read.

        :param model_wraper: Unused
        :param text: Unused
        :param target: Class index the attributions are computed for
        :return: Single-element list with a dict holding ``words``,
            ``importances`` and ``delta``
        """
        vis_data_records_base = []
        model_wrapper = AGNewsmodelWrapper(self.model)
        tokenizer = self._get_tokenizer()
        model_wrapper.eval()
        model_wrapper.zero_grad()

        # NOTE(review): special tokens are omitted here but added in
        # preprocess - confirm this asymmetry is intended.
        encoding = tokenizer.encode_plus(
            self.text,
            return_attention_mask=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        # Attributions are computed w.r.t. the embedding output, so the
        # wrapper is fed embeddings rather than token ids.
        input_embedding_test = model_wrapper.model.bert_model.embeddings(input_ids)
        preds = model_wrapper(input_embedding_test, attention_mask)
        out = torch.argmax(preds.cpu().detach(), dim=1).item()

        ig_1 = IntegratedGradients(model_wrapper)
        # Honour the caller-supplied target (previously hard-coded to 1).
        attributions, delta = ig_1.attribute(  # pylint: disable=no-member
            input_embedding_test,
            n_steps=500,
            return_convergence_delta=True,
            target=target,
        )

        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy().tolist())
        feature_imp_dict = {
            "words": tokens,
            "importances": self.summarize_attributions(attributions).tolist(),
            "delta": delta[0].tolist(),
        }

        # The visualization record list is local and not returned; the call is
        # kept so the method performs the same work as before.
        self.add_attributions_to_visualizer(
            attributions, tokens, self.score_func(preds), out, 2, 1, delta, vis_data_records_base
        )
        return [feature_imp_dict]