pipelines/samples/contrib/pytorch-samples/bert/wrapper.py

102 lines
3.6 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=arguments-differ
# pylint: disable=unused-argument
# pylint: disable=abstract-method
"""Bert Wrapper."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AGNewsmodelWrapper(nn.Module):
    """Wrapper that runs a BERT classifier from precomputed embeddings.

    The wrapped ``model`` must expose ``bert_model`` (providing ``encoder``,
    ``pooler``, ``config.num_hidden_layers`` and at least one parameter) and
    the head layers ``fc1``, ``drop`` and ``out``; these are the only
    attributes this class reads.
    """

    def __init__(self, model):
        """Store the model whose layers ``forward`` delegates to.

        Args:
            model : classifier exposing ``bert_model``, ``fc1``, ``drop``
                and ``out``.
        """
        super().__init__()
        self.model = model

    def compute_bert_outputs(  # pylint: disable=no-self-use
        self, model_bert, embedding_input, attention_mask=None, head_mask=None
    ):
        """Computes Bert Outputs.

        Skips the embedding layer: ``embedding_input`` is fed straight to
        ``model_bert.encoder`` and the result is pooled by
        ``model_bert.pooler``.

        Args:
            model_bert : the bert model
            embedding_input : input for the bert encoder. Its first two
                dimensions size the default attention mask (presumably
                ``(batch, seq_len, hidden)`` -- confirm against the caller).
            attention_mask : attention mask with 1 for positions to keep and
                0 for positions to ignore; defaults to all ones.
            head_mask : optional head mask, either 1-D (presumably one entry
                per head, shared by every layer) or 2-D (presumably one row
                per layer). Other ranks are only cast to the model dtype.

        Returns:
            output : tuple ``(sequence_output, pooled_output)`` followed by
                any extra outputs returned by the encoder.
        """
        # Masks are cast to the model's parameter dtype (fp16 compatibility).
        param_dtype = next(model_bert.parameters()).dtype

        if attention_mask is None:
            # Same dtype and device as the embeddings, without a CPU detour.
            attention_mask = embedding_input.new_ones(
                embedding_input.shape[:2]
            )

        # Add two singleton dimensions after the first one so the mask can
        # broadcast inside the encoder.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)
        # Turn the 1/0 keep-mask into an additive one: 0.0 where the mask is
        # 1 and -10000.0 where it is 0.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                # Replicate the single mask along a new leading dimension of
                # size num_hidden_layers.
                head_mask = (
                    head_mask.unsqueeze(0)
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                head_mask = head_mask.expand(
                    model_bert.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=param_dtype)
        else:
            head_mask = [None] * model_bert.config.num_hidden_layers

        encoder_outputs = model_bert.encoder(
            embedding_input, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model_bert.pooler(sequence_output)
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def forward(self, embeddings, attention_mask=None):
        """Forward function.

        Args:
            embeddings : bert embeddings.
            attention_mask: Attention mask value

        Returns:
            output : result of applying ``fc1``, ReLU, ``drop`` and ``out``
                of the wrapped model to the pooled bert output.
        """
        outputs = self.compute_bert_outputs(
            self.model.bert_model, embeddings, attention_mask
        )
        pooled_output = outputs[1]
        output = F.relu(self.model.fc1(pooled_output))
        output = self.model.drop(output)
        return self.model.out(output)