pipelines/samples/contrib/pytorch-samples/bert/bert_train.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=arguments-differ
# pylint: disable=unused-argument
# pylint: disable=abstract-method
"""Bert Training Script."""
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from sklearn.metrics import accuracy_score
from torch import nn
from transformers import AdamW, BertModel


class BertNewsClassifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Bert Model Class."""

    def __init__(self, **kwargs):
        """Initializes the network, optimizer and scheduler."""
        super(BertNewsClassifier, self).__init__()  #pylint: disable=super-with-arguments
        self.pre_trained_model_name = "bert-base-uncased"  #pylint: disable=invalid-name
        self.bert_model = BertModel.from_pretrained(self.pre_trained_model_name)
        # Freeze the pre-trained encoder; only the classification head is trained.
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)
        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)
        # self.bert_model.embedding = self.bert_model.embeddings
        # self.embedding = self.bert_model.embeddings
        self.scheduler = None
        self.optimizer = None
        self.args = kwargs
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.preds = []
        self.target = []

    def compute_bert_outputs(  #pylint: disable=no-self-use
        self, model_bert, embedding_input, attention_mask=None, head_mask=None
    ):
        """Computes Bert Outputs.

        Args:
            model_bert : the bert model
            embedding_input : input for bert embeddings.
            attention_mask : attention mask
            head_mask : head mask
        Returns:
            output : the bert output
        """
        if attention_mask is None:
            attention_mask = torch.ones(  #pylint: disable=no-member
                embedding_input.shape[0], embedding_input.shape[1]
            ).to(embedding_input)
        # Broadcast the mask to [batch, 1, 1, seq_len] and turn masked positions
        # into large negative values so softmax attention ignores them.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(model_bert.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
                    -1
                ).unsqueeze(-1)
                head_mask = head_mask.expand(
                    model_bert.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
                head_mask = head_mask.to(
                    dtype=next(model_bert.parameters()).dtype
                )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * model_bert.config.num_hidden_layers
        encoder_outputs = model_bert.encoder(
            embedding_input, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model_bert.pooler(sequence_output)
        outputs = (
            sequence_output,
            pooled_output,
        ) + encoder_outputs[1:]
        return outputs

    def forward(self, input_ids, attention_mask=None):
        """Forward function.

        Args:
            input_ids: Input data
            attention_mask: Attention mask value
        Returns:
            output - Type of news for the given news snippet
        """
        embedding_input = self.bert_model.embeddings(input_ids)
        outputs = self.compute_bert_outputs(
            self.bert_model, embedding_input, attention_mask
        )
        pooled_output = outputs[1]
        output = torch.tanh(self.fc1(pooled_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    def training_step(self, train_batch, batch_idx):
        """Trains the model on a batch of data and returns the training loss.

        Args:
            train_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - Training loss
        """
        input_ids = train_batch["input_ids"].to(self.device)
        attention_mask = train_batch["attention_mask"].to(self.device)
        targets = train_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss, "acc": self.train_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        """Performs test and computes the accuracy of the model.

        Args:
            test_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - Testing accuracy
        """
        input_ids = test_batch["input_ids"].to(self.device)
        attention_mask = test_batch["attention_mask"].to(self.device)
        targets = test_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        test_acc = accuracy_score(y_hat.cpu(), targets.cpu())
        self.test_acc(y_hat, targets)
        self.preds += y_hat.tolist()
        self.target += targets.tolist()
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(test_acc)}  #pylint: disable=no-member

    def validation_step(self, val_batch, batch_idx):
        """Performs validation of data in batches.

        Args:
            val_batch: Batch data
            batch_idx: Batch indices
        Returns:
            output - valid step loss
        """
        input_ids = val_batch["input_ids"].to(self.device)
        attention_mask = val_batch["attention_mask"].to(self.device)
        targets = val_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  #pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": self.val_acc.compute()}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
            output - Initialized optimizer and scheduler
        """
        self.optimizer = AdamW(self.parameters(), lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]
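

# Minimal smoke-run sketch: one way to drive the LightningModule above end to
# end. Synthetic token ids stand in for the real news data that the pipeline's
# data-prep component normally supplies; the `_RandomNewsDataset` helper and
# the hyperparameter values below are illustrative assumptions, and the code
# assumes the transformers/torchmetrics versions this sample targets.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _RandomNewsDataset(Dataset):
        """Synthetic dataset yielding batches in the format the model expects."""

        def __init__(self, vocab_size, num_samples=64, seq_len=32, num_classes=4):
            self.input_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
            self.attention_mask = torch.ones(num_samples, seq_len, dtype=torch.long)
            self.targets = torch.randint(0, num_classes, (num_samples,))

        def __len__(self):
            return len(self.targets)

        def __getitem__(self, idx):
            return {
                "input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx],
                "targets": self.targets[idx],
            }

    model = BertNewsClassifier(lr=0.001)
    dataset = _RandomNewsDataset(model.bert_model.config.vocab_size)
    loader = DataLoader(dataset, batch_size=8)
    # Keep the run short: a couple of train/val batches is enough to exercise
    # training_step, validation_step and the ReduceLROnPlateau "val_loss" monitor.
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(model, loader, loader)
    trainer.test(model, loader)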