#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#pylint: disable=no-member,unused-argument,arguments-differ
"""Cifar10 training module."""
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch import nn
from torchvision import models


class CIFAR10Classifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Cifar10 model class."""

    def __init__(self, **kwargs):
        """Initializes the network, optimizer and scheduler."""
        super().__init__()
        # `pretrained=True` is the pre-0.13 torchvision API (newer releases
        # use the `weights=` argument instead).
        self.model_conv = models.resnet50(pretrained=True)
        # Freeze the pretrained backbone; only the replaced fc head is trained.
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)
        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        # Accuracy() with no arguments assumes torchmetrics < 0.11; newer
        # releases require Accuracy(task="multiclass", num_classes=10).
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.preds = []
        self.target = []
        # Used by Lightning to log the computational graph and input shapes.
        self.example_input_array = torch.rand((1, 3, 64, 64))

    def forward(self, x_var):
        """Forward function."""
        out = self.model_conv(x_var)
        return out

    def training_step(self, train_batch, batch_idx):
        """Training step.

        Args:
            train_batch: training batch
            batch_idx: batch id number

        Returns:
            dict with the training loss
        """
        if batch_idx == 0:
            self.reference_image = (train_batch[0][0]).unsqueeze(0)  #pylint: disable=attribute-defined-outside-init
            # self.reference_image.resize((1,1,28,28))
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)
        x_var, y_var = train_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y_var)
        self.log("train_acc", self.train_acc.compute())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Testing step.

        Args:
            test_batch: test batch data
            batch_idx: test batch id

        Returns:
            dict with the test accuracy
        """
        x_var, y_var = test_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        accelerator = self.args.get("accelerator", None)
        if accelerator is not None:
            # Synchronize the logged value across processes in distributed runs.
            self.log("test_loss", loss, sync_dist=True)
        else:
            self.log("test_loss", loss)
        self.test_acc(y_hat, y_var)
        # Accumulate predictions and targets (e.g. for a confusion matrix).
        self.preds += y_hat.tolist()
        self.target += y_var.tolist()
        self.log("test_acc", self.test_acc.compute())
        return {"test_acc": self.test_acc.compute()}

    def validation_step(self, val_batch, batch_idx):
        """Validation step.

        Args:
            val_batch: val batch data
            batch_idx: val batch id

        Returns:
            dict with the validation loss
        """
        x_var, y_var = val_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        accelerator = self.args.get("accelerator", None)
        if accelerator is not None:
            self.log("val_loss", loss, sync_dist=True)
        else:
            self.log("val_loss", loss)
        self.val_acc(y_hat, y_var)
        self.log("val_acc", self.val_acc.compute())
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Returns:
            output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.args.get("lr", 0.001),
            weight_decay=self.args.get("weight_decay", 0),
            eps=self.args.get("eps", 1e-8),
        )
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]
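    # Note: Lightning pairs each optimizer with its scheduler dict and reads
    # the "monitor" key above, so ReduceLROnPlateau steps on the "val_loss"
    # value logged in validation_step. Since the backbone is frozen in
    # __init__, gradients (and hence optimizer updates) reach only the fc head.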

    def makegrid(self, output, numrows):  #pylint: disable=no-self-use
        """Makes grids.

        Args:
            output: Tensor output
            numrows: num of rows

        Returns:
            c_array: grid array
        """
        outer = torch.Tensor.cpu(output).detach()
        plt.figure(figsize=(20, 5))
        b_array = np.array([]).reshape(0, outer.shape[2])
        c_array = np.array([]).reshape(numrows * outer.shape[2], 0)
        i = 0
        j = 0
        # Tile the feature maps of the first sample into a single 2D image,
        # stacking numrows maps per column.
        while i < outer.shape[1]:
            img = outer[0][i]
            b_array = np.concatenate((img, b_array), axis=0)
            j += 1
            if j == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, outer.shape[2])
                j = 0
            i += 1
        return c_array
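    # A roughly equivalent tiling could be built with torchvision's helper
    # (a sketch, assuming equally sized single-channel feature maps; make_grid
    # adds padding between tiles, so the output is not pixel-identical):
    #   from torchvision.utils import make_grid
    #   grid = make_grid(outer[0].unsqueeze(1), nrow=numrows)[0].numpy()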

    def show_activations(self, x_var):
        """Shows activations.

        Args:
            x_var: x variable
        """
        # logging reference image
        self.logger.experiment.add_image(
            "input",
            torch.Tensor.cpu(x_var[0][0]),
            self.current_epoch,
            dataformats="HW",
        )

        # logging layer 1 activations
        out = self.model_conv.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image(
            "layer 1", c_grid, self.current_epoch, dataformats="HW"
        )

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Args:
            outputs: outputs of train end
        """
        self.show_activations(self.reference_image)
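

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original sample): in the pipeline this
# module is driven by a generic trainer component together with a CIFAR-10
# LightningDataModule. The block below is one hedged way to smoke-test the
# model locally; the Resize(64) transform matches example_input_array above
# and is an assumption, as are all names and hyperparameters here.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    val_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    model = CIFAR10Classifier(lr=0.001)
    # Cap the number of batches so the smoke test finishes quickly.
    trainer = pl.Trainer(
        max_epochs=1, limit_train_batches=10, limit_val_batches=5
    )
    trainer.fit(
        model,
        DataLoader(train_set, batch_size=32, num_workers=2),
        DataLoader(val_set, batch_size=32, num_workers=2),
    )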