101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# pylint: disable=W0221
|
|
# pylint: disable=W0613
|
|
# pylint: disable=W0223
|
|
from argparse import ArgumentParser
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchmetrics import Accuracy
|
|
|
|
|
|
class IrisClassification(pl.LightningModule):
    """Small fully connected classifier with 4 input features and 3 classes.

    The network is ``Linear(4, 10) -> Linear(10, 10) -> Linear(10, 3)`` with a
    ReLU after every layer. Accuracy is tracked with a separate metric object
    for each of the train, validation and test loops.
    """

    def __init__(self, **kwargs):
        """Build the layers, the loss criterion and the accuracy metrics.

        :param kwargs: Free-form hyperparameters, stored as-is on
            ``self.args``. Only the ``lr`` key is read by this class
            (in ``configure_optimizers``).
        """
        super().__init__()

        # One metric per loop so their accumulated states never mix.
        # NOTE(review): recent torchmetrics releases expect an explicit
        # ``task``/``num_classes`` argument here -- confirm against the
        # torchmetrics version this project pins.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Compute the class scores for a batch.

        :param x: Input tensor whose last dimension has size 4.

        :return: Tensor with 3 scores per sample (non-negative, because the
            output layer is followed by a ReLU).
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): the ReLU on the output layer clamps the scores that
        # are fed to the cross-entropy loss at zero. It is kept unchanged so
        # existing checkpoints behave the same -- confirm it is intentional.
        x = F.relu(self.fc3(x))
        return x

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Add model specific arguments like learning rate

        :param parent_parser: Application specific parser

        :return: Returns the augmented argument parser
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=0.01,
            metavar="LR",
            # The help text must state the same default as ``default=`` above.
            help="learning rate (default: 0.01)",
        )
        return parser

    def configure_optimizers(self):
        """Create the Adam optimizer.

        :return: ``torch.optim.Adam`` over all module parameters, using the
            ``lr`` entry of ``self.args`` (0.01 when it is absent).
        """
        return torch.optim.Adam(self.parameters(), lr=self.args.get("lr", 0.01))

    def _predict(self, x):
        """Return ``(logits, predicted class indices)`` for a batch."""
        # Call the module rather than ``forward`` directly so that any
        # registered module hooks run as well.
        logits = self(x)
        return logits, torch.argmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        """Run one training batch and log ``train_acc`` and ``train_loss``.

        :param batch: Pair ``(x, y)`` of inputs and target class indices.
        :param batch_idx: Index of the batch (unused).

        :return: Dictionary holding the loss under the ``"loss"`` key.
        """
        x, y = batch
        logits, y_hat = self._predict(x)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(y_hat, y)
        # Log the metric object itself (not ``.compute()``): Lightning then
        # computes it once per epoch and resets its state afterwards, instead
        # of reporting a running value that accumulates across epochs.
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log ``val_acc`` and ``val_loss``.

        :param batch: Pair ``(x, y)`` of inputs and target class indices.
        :param batch_idx: Index of the batch (unused).
        """
        x, y = batch
        logits, y_hat = self._predict(x)
        # Same criterion object as in ``training_step`` so both loops always
        # measure the same loss.
        loss = self.cross_entropy_loss(logits, y)
        self.val_acc(y_hat, y)
        # Metric object is logged so it is aggregated and reset per epoch.
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True)
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Run one test batch and log ``test_acc``.

        :param batch: Pair ``(x, y)`` of inputs and target class indices.
        :param batch_idx: Index of the batch (unused).
        """
        x, y = batch
        _, y_hat = self._predict(x)
        self.test_acc(y_hat, y)
        # Metric object is logged so it is aggregated and reset per epoch.
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)