209 lines
6.7 KiB
Python
209 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#pylint: disable=no-member,unused-argument,arguments-differ
|
|
"""Cifar10 training module."""
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torchmetrics import Accuracy
|
|
from torch import nn
|
|
from torchvision import models
|
|
|
|
|
|
class CIFAR10Classifier(pl.LightningModule):  #pylint: disable=too-many-ancestors,too-many-instance-attributes
    """Cifar10 model class.

    Fine-tunes an ImageNet-pretrained ResNet-50: every backbone parameter is
    frozen and only the replacement 10-way ``fc`` head is trainable.

    Keyword arguments given to the constructor are stored in ``self.args`` and
    read by the hooks below (``lr``, ``weight_decay``, ``eps``,
    ``accelerator``).
    """

    def __init__(self, **kwargs):
        """Initializes the network, metrics and optimizer/scheduler slots.

        Args:
            **kwargs: hyper-parameters; the keys read by this class are
                ``lr``, ``weight_decay``, ``eps`` and ``accelerator``.
        """
        super().__init__()

        self.model_conv = models.resnet50(pretrained=True)
        # Freeze the pretrained backbone; only the new head below is trained.
        for param in self.model_conv.parameters():
            param.requires_grad = False
        num_ftrs = self.model_conv.fc.in_features
        num_classes = 10
        # The freshly created Linear head is trainable (requires_grad=True).
        self.model_conv.fc = nn.Linear(num_ftrs, num_classes)

        self.scheduler = None
        self.optimizer = None
        self.args = kwargs

        # NOTE(review): argument-less Accuracy() matches older torchmetrics
        # releases; newer ones require ``task=``/``num_classes=`` - confirm
        # against the pinned version.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Predictions / labels accumulated by ``test_step`` (never cleared here).
        self.preds = []
        self.target = []
        # First training image of the current epoch; captured in
        # ``training_step`` and visualised in ``training_epoch_end``.
        self.reference_image = None
        # Dummy input batch exposed through Lightning's
        # ``example_input_array`` attribute.
        self.example_input_array = torch.rand((1, 3, 64, 64))

    def forward(self, x_var):
        """Forward function.

        Args:
            x_var: input image batch, passed straight to the ResNet-50.

        Returns:
            The network's output logits (one score per class).
        """
        return self.model_conv(x_var)

    def training_step(self, train_batch, batch_idx):
        """Training step.

        Args:
            train_batch: ``(images, labels)`` pair for one batch.
            batch_idx: index of the batch within the current epoch.

        Returns:
            dict holding the cross-entropy ``loss`` to back-propagate.
        """
        x_var, y_var = train_batch
        if batch_idx == 0:
            # First batch of a new epoch: the accuracy is logged via
            # ``compute()`` (not as a metric object), so Lightning will not
            # reset it for us - do it here so epochs are not mixed together.
            self.train_acc.reset()
            self.reference_image = x_var[0].unsqueeze(0)
            print("\n\nREFERENCE IMAGE!!!")
            print(self.reference_image.shape)

        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y_var)
        # Running accuracy over the batches seen so far in this epoch.
        self.log("train_acc", self.train_acc.compute())
        return {"loss": loss}

    def test_step(self, test_batch, batch_idx):
        """Testing step.

        Args:
            test_batch: ``(images, labels)`` pair for one test batch.
            batch_idx: index of the batch within the test run.

        Returns:
            dict with the running ``test_acc`` over the batches seen so far.
        """
        if batch_idx == 0:
            # Start of a test run: discard accuracy state from earlier runs.
            self.test_acc.reset()

        x_var, y_var = test_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)

        accelerator = self.args.get("accelerator", None)
        # Synchronise the logged value across processes only when an
        # accelerator was configured.
        self.log("test_loss", loss, sync_dist=accelerator is not None)

        self.test_acc(y_hat, y_var)
        self.preds += y_hat.tolist()
        self.target += y_var.tolist()

        test_acc = self.test_acc.compute()
        self.log("test_acc", test_acc)
        return {"test_acc": test_acc}

    def validation_step(self, val_batch, batch_idx):
        """Validation step.

        Args:
            val_batch: ``(images, labels)`` pair for one validation batch.
            batch_idx: index of the batch within the validation loop.

        Returns:
            dict with the batch cross-entropy under ``val_step_loss`` and
            ``val_loss``.
        """
        if batch_idx == 0:
            # Start of a validation loop (including the sanity check): reset
            # so each logged val_acc covers only the current loop.
            self.val_acc.reset()

        x_var, y_var = val_batch
        output = self.forward(x_var)
        _, y_hat = torch.max(output, dim=1)
        loss = F.cross_entropy(output, y_var)

        accelerator = self.args.get("accelerator", None)
        # ``val_loss`` is also the metric monitored by the LR scheduler.
        self.log("val_loss", loss, sync_dist=accelerator is not None)

        self.val_acc(y_hat, y_var)
        self.log("val_acc", self.val_acc.compute())
        return {"val_step_loss": loss, "val_loss": loss}

    def configure_optimizers(self):
        """Initializes the optimizer and learning rate scheduler.

        Adam runs over ``self.parameters()``; the frozen backbone parameters
        receive no gradients, so effectively only the head is updated. The
        learning rate is cut by 5x (factor 0.2) when ``val_loss`` has not
        improved for 3 epochs, down to a floor of 1e-6.

        Returns:
            output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.args.get("lr", 0.001),
            weight_decay=self.args.get("weight_decay", 0),
            eps=self.args.get("eps", 1e-8),
        )
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=3,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]

    def makegrid(self, output, numrows):  #pylint: disable=no-self-use
        """Makes grids.

        Tiles the feature maps of the first sample in ``output`` into one 2-D
        array. Every ``numrows`` maps are stacked vertically into a column,
        with each new map placed above the earlier ones. The columns are then
        laid side by side.

        Args:
            output: 4-D tensor indexed as ``[sample, channel, row, col]``.
            numrows: number of feature maps stacked in each column.

        Returns:
            c_array : grid array of shape
                ``(numrows * rows, (channels // numrows) * cols)``; trailing
                maps that do not fill a complete column are left out.
        """
        outer = output.detach().cpu().numpy()
        height, width = outer.shape[2], outer.shape[3]
        # Column buffers must be sized by the map *width*, not its height,
        # or non-square feature maps fail to concatenate.
        b_array = np.array([]).reshape(0, width)
        c_array = np.array([]).reshape(numrows * height, 0)
        filled = 0
        for img in outer[0]:
            b_array = np.concatenate((img, b_array), axis=0)
            filled += 1
            if filled == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, width)
                filled = 0
        return c_array

    def show_activations(self, x_var):
        """Shows activations.

        Logs the first channel of the first input image, plus a grid of the
        ResNet ``conv1`` feature maps for ``x_var``, to the experiment logger
        at the current epoch.

        Args:
            x_var: image batch; only its first sample is visualised.
        """
        # NOTE(review): add_image(..., dataformats=...) assumes a TensorBoard
        # SummaryWriter behind self.logger.experiment - confirm logger type.

        # logging reference image (first channel only, as an HW image)
        self.logger.experiment.add_image(
            "input",
            x_var[0][0].cpu(),
            self.current_epoch,
            dataformats="HW",
        )

        # logging layer 1 activations; purely for visualisation, so no
        # autograd graph is needed.
        with torch.no_grad():
            out = self.model_conv.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image(
            "layer 1", c_grid, self.current_epoch, dataformats="HW"
        )

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Visualises the activations of the reference image captured on the
        first batch of the epoch.

        Args:
            outputs: outputs of the epoch's training steps (unused).
        """
        if self.reference_image is None:
            # No training batch has been seen yet; nothing to visualise.
            return
        self.show_activations(self.reference_image)
|
|
|