#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=arguments-differ
# pylint: disable=unused-argument
# pylint: disable=abstract-method
"""Bert Training Script."""

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from sklearn.metrics import accuracy_score
from torch import nn
from transformers import AdamW, BertModel


class BertNewsClassifier(pl.LightningModule):  # pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Bert Model Class."""

    def __init__(self, **kwargs):
        """Initializes the network, optimizer and scheduler."""
        super(BertNewsClassifier, self).__init__()  # pylint: disable=super-with-arguments
        self.pre_trained_model_name = "bert-base-uncased"  # pylint: disable=invalid-name
        self.bert_model = BertModel.from_pretrained(self.pre_trained_model_name)
        # Freeze the pre-trained BERT weights; only the classification head is trained.
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.drop = nn.Dropout(p=0.2)
        # assigning labels
        self.class_names = ["World", "Sports", "Business", "Sci/Tech"]
        n_classes = len(self.class_names)

        self.fc1 = nn.Linear(self.bert_model.config.hidden_size, 512)
        self.out = nn.Linear(512, n_classes)
        # self.bert_model.embedding = self.bert_model.embeddings
        # self.embedding = self.bert_model.embeddings

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.preds = []
        self.target = []

    def compute_bert_outputs(  # pylint: disable=no-self-use
        self, model_bert, embedding_input, attention_mask=None, head_mask=None
    ):
        """Computes Bert Outputs.

        Args:
            model_bert : the bert model
            embedding_input : input for bert embeddings.
            attention_mask : attention mask
            head_mask : head mask
        Returns:
            output : the bert output
        """
        if attention_mask is None:
            attention_mask = torch.ones(  # pylint: disable=no-member
                embedding_input.shape[0], embedding_input.shape[1]
            ).to(embedding_input)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = extended_attention_mask.to(
            dtype=next(model_bert.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
                    -1
                ).unsqueeze(-1)
                head_mask = head_mask.expand(
                    model_bert.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
                head_mask = head_mask.to(
                    dtype=next(model_bert.parameters()).dtype
                )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * model_bert.config.num_hidden_layers

        encoder_outputs = model_bert.encoder(
            embedding_input, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = model_bert.pooler(sequence_output)
        outputs = (
            sequence_output,
            pooled_output,
        ) + encoder_outputs[1:]
        return outputs

    def forward(self, input_ids, attention_mask=None):
        """Forward function.

        Args:
            input_ids: Input data
            attention_mask: Attention mask value

        Returns:
            output - Type of news for the given news snippet
        """
        embedding_input = self.bert_model.embeddings(input_ids)
        outputs = self.compute_bert_outputs(
            self.bert_model, embedding_input, attention_mask
        )
        pooled_output = outputs[1]
        output = torch.tanh(self.fc1(pooled_output))
        output = self.drop(output)
        output = self.out(output)
        return output

    def training_step(self, train_batch, batch_idx):
        """Trains the model on a batch of data and returns the training loss.

        Args:
            train_batch: Batch data
            batch_idx: Batch indices

        Returns:
            output - Training loss
        """
        input_ids = train_batch["input_ids"].to(self.device)
        attention_mask = train_batch["attention_mask"].to(self.device)
        targets = train_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  # pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.train_acc(y_hat, targets)
        self.log("train_acc", self.train_acc.compute())
        self.log("train_loss", loss)
        return {"loss": loss, "acc": self.train_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        """Performs a test step and computes the accuracy of the model.

        Args:
            test_batch: Batch data
            batch_idx: Batch indices

        Returns:
            output - Testing accuracy
        """
        input_ids = test_batch["input_ids"].to(self.device)
        attention_mask = test_batch["attention_mask"].to(self.device)
        targets = test_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  # pylint: disable=no-member
        test_acc = accuracy_score(y_hat.cpu(), targets.cpu())
        self.test_acc(y_hat, targets)
        self.preds += y_hat.tolist()
        self.target += targets.tolist()
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": torch.tensor(test_acc)}  # pylint: disable=no-member

    def validation_step(self, val_batch, batch_idx):
        """Performs validation on a batch of data.

        Args:
            val_batch: Batch data
            batch_idx: Batch indices

        Returns:
            output - Validation step loss
        """
        input_ids = val_batch["input_ids"].to(self.device)
        attention_mask = val_batch["attention_mask"].to(self.device)
        targets = val_batch["targets"].to(self.device)
        output = self.forward(input_ids, attention_mask)
        _, y_hat = torch.max(output, dim=1)  # pylint: disable=no-member
        loss = F.cross_entropy(output, targets)
        self.val_acc(y_hat, targets)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)
        return {"val_step_loss": loss, "acc": self.val_acc.compute()}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
            output - Initialized optimizer and scheduler
        """
        self.optimizer = AdamW(self.parameters(), lr=self.args.get("lr", 0.001))
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]
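

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original training script): a minimal smoke
# test of the model plus a rough outline of how it could be fitted with a
# Lightning Trainer. "NewsDataModule" in the comments below is a hypothetical
# placeholder for whatever datamodule/dataloaders the surrounding project
# provides; only the dummy forward pass is intended to run as-is.
if __name__ == "__main__":
    model = BertNewsClassifier(lr=0.001)

    # A batch of 2 sequences of 16 random token ids should yield one logit
    # per class, i.e. a tensor of shape (2, 4).
    dummy_input_ids = torch.randint(
        0, model.bert_model.config.vocab_size, (2, 16)
    )
    logits = model(dummy_input_ids)
    print(logits.shape)

    # Training would then look roughly like this, assuming batches are dicts
    # with "input_ids", "attention_mask" and "targets" keys as used above:
    # trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(model, datamodule=NewsDataModule())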