pipelines/components/PyTorch/pytorch-kfp-components/tests/iris/iris_classification.py

101 lines
3.0 KiB
Python

#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=W0221
# pylint: disable=W0613
# pylint: disable=W0223
from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
class IrisClassification(pl.LightningModule):
    """Fully connected classifier: 4 input features -> 10 -> 10 -> 3 logits.

    The layer sizes match the Iris dataset suggested by the class name
    (4 measurements, 3 classes). Training, validation and test accuracy are
    tracked with separate ``torchmetrics.Accuracy`` instances so their
    accumulated state never mixes between stages.
    """

    # Single source of truth for the learning rate used by both the CLI
    # default and the optimizer fallback, so the two cannot drift apart.
    _DEFAULT_LR = 0.01

    def __init__(self, **kwargs):
        """Build the layers, the loss and the per-stage accuracy metrics.

        :param kwargs: Free-form hyper-parameters kept in ``self.args``.
            Only ``lr`` is read by this class (see ``configure_optimizers``).
        """
        super().__init__()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs

        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw (unnormalized) class logits for a batch.

        :param x: Input tensor whose last dimension has 4 features.
        :return: Tensor whose last dimension holds the 3 class logits.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: nn.CrossEntropyLoss applies
        # log-softmax itself, and a ReLU here would clamp every negative
        # logit to zero and block its gradient.
        return self.fc3(x)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Add model specific arguments like learning rate

        :param parent_parser: Application specific parser
        :return: Returns the augmented argument parser
        """
        default_lr = IrisClassification._DEFAULT_LR
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--lr",
            type=float,
            default=default_lr,
            metavar="LR",
            help=f"learning rate (default: {default_lr})",
        )
        return parser

    def configure_optimizers(self):
        """Create the Adam optimizer.

        :return: ``torch.optim.Adam`` over all module parameters, using the
            ``lr`` entry of ``self.args`` or the class default when absent.
        """
        return torch.optim.Adam(
            self.parameters(),
            lr=self.args.get("lr", self._DEFAULT_LR),
        )

    def training_step(self, batch, batch_idx):
        """Run one training batch and log loss and accuracy.

        :param batch: ``(features, targets)`` pair.
        :param batch_idx: Index of the batch (unused).
        :return: Dict with the ``loss`` tensor to back-propagate.
        """
        x, y = batch
        logits = self.forward(x)
        y_hat = torch.argmax(logits, dim=1)
        loss = self.cross_entropy_loss(logits, y)

        self.train_acc(y_hat, y)
        # Log the metric object rather than ``compute()``: Lightning then
        # computes it once at epoch end and resets its state afterwards,
        # instead of averaging ever-growing running values.
        self.log(
            "train_acc",
            self.train_acc,
            on_step=False,
            on_epoch=True,
        )
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log loss and accuracy.

        :param batch: ``(features, targets)`` pair.
        :param batch_idx: Index of the batch (unused).
        """
        x, y = batch
        logits = self.forward(x)
        y_hat = torch.argmax(logits, dim=1)
        # Same loss module as training, so both stages are measured alike.
        loss = self.cross_entropy_loss(logits, y)

        self.val_acc(y_hat, y)
        # Metric object is logged so it is aggregated and reset per epoch.
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Run one test batch and log accuracy.

        :param batch: ``(features, targets)`` pair.
        :param batch_idx: Index of the batch (unused).
        """
        x, y = batch
        logits = self.forward(x)
        y_hat = torch.argmax(logits, dim=1)

        self.test_acc(y_hat, y)
        # Metric object is logged so it is aggregated and reset per epoch.
        self.log("test_acc", self.test_acc)