pipelines/components/PyTorch/pytorch-kfp-components/tests/test_trainer.py

#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for trainer component."""
import os
import shutil
import sys
import tempfile

import pytest
import pytorch_lightning

from pytorch_kfp_components.components.trainer.component import Trainer

dirname, filename = os.path.split(os.path.abspath(__file__))
IRIS_DIR = os.path.join(dirname, "iris")
sys.path.insert(0, IRIS_DIR)

MODULE_FILE_ARGS = {"lr": 0.1}
TRAINER_ARGS = {"max_epochs": 5}
DATA_MODULE_ARGS = {"num_workers": 2}


# pylint:disable=redefined-outer-name
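# pytest injects fixtures by matching parameter names, so the test functions
# below deliberately take an argument that shadows the `trainer_params`
# fixture name; hence the pylint disable above.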
@pytest.fixture(scope="class")
def trainer_params():
    """Returns the default parameter set for invoking training."""
    trainer_params = {
        "module_file": "iris_classification.py",
        "data_module_file": "iris_data_module.py",
        "module_file_args": MODULE_FILE_ARGS,
        "data_module_args": DATA_MODULE_ARGS,
        "trainer_args": TRAINER_ARGS,
    }
    return trainer_params


MANDATORY_ARGS = [
    "module_file",
    "data_module_file",
]

OPTIONAL_ARGS = ["module_file_args", "data_module_args", "trainer_args"]

DEFAULT_MODEL_NAME = "model_state_dict.pth"
DEFAULT_SAVE_PATH = f"/tmp/{DEFAULT_MODEL_NAME}"
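
# When module_file_args carries no checkpoint_dir, the trainer component
# falls back to saving the state dict as /tmp/model_state_dict.pth; the
# tests below that assert on DEFAULT_SAVE_PATH rely on this fallback.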


def invoke_training(trainer_params):  # pylint: disable=W0621
    """This function invokes the training process."""
    trainer = Trainer(
        module_file=trainer_params["module_file"],
        data_module_file=trainer_params["data_module_file"],
        module_file_args=trainer_params["module_file_args"],
        trainer_args=trainer_params["trainer_args"],
        data_module_args=trainer_params["data_module_args"],
    )
    return trainer
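
# Note: module_file and data_module_file are passed as bare filenames; the
# component presumably resolves them via sys.path, which is why IRIS_DIR is
# prepended to sys.path at the top of this file.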
@pytest.mark.parametrize("mandatory_key", MANDATORY_ARGS)
def test_mandatory_keys_type_check(trainer_params, mandatory_key):
"""Tests the uncexpected 'type' of mandatory args.
Args:
mandatory_key : mandatory arguments for inivoking training
"""
test_input = ["input_path"]
trainer_params[mandatory_key] = test_input
expected_exception_msg = (
f"{mandatory_key} must be of type <class 'str'> "
f"but received as {type(test_input)}"
)
with pytest.raises(TypeError, match=expected_exception_msg):
invoke_training(trainer_params=trainer_params)
@pytest.mark.parametrize("optional_key", OPTIONAL_ARGS)
def test_optional_keys_type_check(trainer_params, optional_key):
"""Tests the unexpected 'type' of optional args.
Args:
optional_key: optional arguments for invoking training
"""
test_input = "test_input"
trainer_params[optional_key] = test_input
expected_exception_msg = (
f"{optional_key} must be of type <class 'dict'> "
f"but received as {type(test_input)}"
)
with pytest.raises(TypeError, match=expected_exception_msg):
invoke_training(trainer_params=trainer_params)
@pytest.mark.parametrize("input_key", MANDATORY_ARGS + ["module_file_args"])
def test_mandatory_params(trainer_params, input_key):
"""Test for empty mandatory arguments.
Args:
input_key: name of the mandatory arg for training
"""
trainer_params[input_key] = None
expected_exception_msg = (
f"{input_key} is not optional. "
f"Received value: {trainer_params[input_key]}"
)
with pytest.raises(ValueError, match=expected_exception_msg):
invoke_training(trainer_params=trainer_params)


def test_data_module_args_optional(trainer_params):
    """Test for empty optional argument: data_module_args."""
    trainer_params["data_module_args"] = None
    invoke_training(trainer_params=trainer_params)
    assert os.path.exists(DEFAULT_SAVE_PATH)
    os.remove(DEFAULT_SAVE_PATH)


def test_trainer_args_none(trainer_params):
    """Test for trainer-specific arguments set to None."""
    trainer_params["trainer_args"] = None
    expected_exception_msg = r"trainer_args must be a dict"
    with pytest.raises(TypeError, match=expected_exception_msg):
        invoke_training(trainer_params=trainer_params)


def test_training_success(trainer_params):
    """Test the training success case with all required args."""
    trainer = invoke_training(trainer_params=trainer_params)
    assert os.path.exists(DEFAULT_SAVE_PATH)
    os.remove(DEFAULT_SAVE_PATH)
    assert hasattr(trainer, "ptl_trainer")
    assert isinstance(
        trainer.ptl_trainer, pytorch_lightning.trainer.trainer.Trainer
    )


def test_training_success_with_custom_model_name(trainer_params):
    """Test for successful training with a custom model name."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["module_file_args"]["model_name"] = "iris.pth"
    invoke_training(trainer_params=trainer_params)
    assert "iris.pth" in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)
    trainer_params["module_file_args"].pop("checkpoint_dir")
    trainer_params["module_file_args"].pop("model_name")


def test_training_failure_with_empty_module_file_args(trainer_params):
    """Test that training fails with empty module file args."""
    trainer_params["module_file_args"] = {}
    exception_msg = "module_file_args is not optional. Received value: {}"
    with pytest.raises(ValueError, match=exception_msg):
        invoke_training(trainer_params=trainer_params)


def test_training_success_with_empty_trainer_args(trainer_params):
    """Test for successful training with empty trainer args."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["max_epochs"] = 5
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["trainer_args"] = {}
    invoke_training(trainer_params=trainer_params)
    assert DEFAULT_MODEL_NAME in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)


def test_training_success_with_empty_data_module_args(trainer_params):
    """Test for successful training with data module args set to None."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["data_module_args"] = None
    invoke_training(trainer_params=trainer_params)
    assert DEFAULT_MODEL_NAME in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)


def test_trainer_output(trainer_params):
    """Test for successful training with proper saving of training output."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer = invoke_training(trainer_params=trainer_params)
    assert hasattr(trainer, "output_dict")
    assert trainer.output_dict is not None
    assert trainer.output_dict["model_save_path"] == os.path.join(
        tmp_dir, DEFAULT_MODEL_NAME
    )
    assert isinstance(
        trainer.output_dict["ptl_trainer"],
        pytorch_lightning.trainer.trainer.Trainer
    )
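

# A minimal way to run this suite locally (assuming pytest is installed and
# the iris example files are present under tests/iris/ as set up above):
#
#     pytest test_trainer.py -v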