feat(components): Addressing Review comments on Trainer component for PyTorch - KFP (#5814)

* fix url and keywords in setup.py

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* Minio UT fix

Signed-off-by: ankan94 <ankan@ideas2it.com>

* Base component metaclass modification

Signed-off-by: ankan94 <ankan@ideas2it.com>

* Adding fast failing for trainer component.

Signed-off-by: ankan94 <ankan@ideas2it.com>

* fixing minio test

Signed-off-by: ankan94 <ankan@ideas2it.com>

* Fixing lint issues

Signed-off-by: ankan94 <ankan@ideas2it.com>

Co-authored-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>
This commit is contained in:
ankan94 2021-06-17 22:12:28 +05:30 committed by GitHub
parent 60ed8e45f7
commit 72f7464d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 25 additions and 27 deletions

View File

@ -15,11 +15,10 @@
"""Pipeline Base component class."""
import abc
from six import with_metaclass
from pytorch_kfp_components.types import standard_component_specs
class BaseComponent(with_metaclass(abc.ABCMeta, object)): # pylint: disable=R0903
class BaseComponent(metaclass=abc.ABCMeta): # pylint: disable=R0903
"""Pipeline Base component class."""
def __init__(self):

View File

@ -161,7 +161,7 @@ class Executor(BaseExecutor):
print("Running Archiver cmd: ", archiver_cmd)
proc = subprocess.Popen(
proc = subprocess.Popen( #pylint: disable=consider-using-with
archiver_cmd,
shell=True,
stdout=subprocess.PIPE,

View File

@ -13,10 +13,10 @@
# limitations under the License.
"""Minio Executor Module."""
import os
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
from pytorch_kfp_components.types import standard_component_specs
import urllib3
from minio import Minio #pylint: disable=no-name-in-module
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
from pytorch_kfp_components.types import standard_component_specs
class Executor(BaseExecutor):

View File

@ -19,7 +19,7 @@ from pytorch_kfp_components.components.base.base_component import BaseComponent
from pytorch_kfp_components.types import standard_component_specs
class Trainer(BaseComponent):
class Trainer(BaseComponent): #pylint: disable=too-few-public-methods
"""Initializes the Trainer class."""
def __init__( # pylint: disable=R0913

View File

@ -30,7 +30,7 @@ class Executor(GenericExecutor):
def __init__(self): # pylint:disable=useless-super-delegation
super().__init__()
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict): #pylint: disable=too-many-locals
"""This function of the Executor invokes the PyTorch Lightning training
loop.
@ -63,7 +63,7 @@ class Executor(GenericExecutor):
trainer_args,
module_file_args,
data_module_args,
) = self._GetFnArgs(
) = self._get_fn_args(
input_dict=input_dict,
output_dict=output_dict,
execution_properties=exec_properties,
@ -75,6 +75,11 @@ class Executor(GenericExecutor):
) = self.derive_model_and_data_module_class(
module_file=module_file, data_module_file=data_module_file
)
if not data_module_class :
raise NotImplementedError(
"Data module class is mandatory. "
"User defined training module is yet to be supported."
)
if data_module_class:
data_module = data_module_class(
**data_module_args if data_module_args else {}
@ -93,8 +98,8 @@ class Executor(GenericExecutor):
parser = Namespace(**module_file_args)
trainer = pl.Trainer.from_argparse_args(parser)
trainer.fit(model, data_module)
trainer.test()
trainer.fit(model, data_module) #pylint: disable=no-member
trainer.test() #pylint: disable=no-member
if "checkpoint_dir" in module_file_args:
model_save_path = module_file_args["checkpoint_dir"]
@ -114,9 +119,3 @@ class Executor(GenericExecutor):
output_dict[standard_component_specs.TRAINER_MODEL_SAVE_PATH
] = model_save_path
output_dict[standard_component_specs.PTL_TRAINER_OBJ] = trainer
else:
raise NotImplementedError(
"Data module class is mandatory. "
"User defined training module is yet to be supported."
)

View File

@ -24,11 +24,11 @@ class GenericExecutor(BaseExecutor):
"""Generic Executor Class that does nothing."""
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
# TODO: Code to train pretrained model
#TODO: Code to train pretrained model #pylint: disable=fixme
pass
def _GetFnArgs(
self, input_dict: dict, output_dict: dict, execution_properties: dict
def _get_fn_args( #pylint: disable=no-self-use
self, input_dict: dict, output_dict: dict, execution_properties: dict #pylint: disable=unused-argument
):
"""Gets the input/output/execution properties from the dictionary.
@ -68,7 +68,7 @@ class GenericExecutor(BaseExecutor):
data_module_args,
)
def derive_model_and_data_module_class(
def derive_model_and_data_module_class( #pylint: disable=no-self-use
self, module_file: str, data_module_file: str
):
"""Derives the model file and data module file.

View File

@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for defining standard specifications and validation of parameter
type."""
#pylint: disable=duplicate-code
TRAINER_MODULE_FILE = "module_file"
TRAINER_DATA_MODULE_FILE = "data_module_file"
TRAINER_DATA_MODULE_ARGS = "data_module_args"
@ -112,7 +111,7 @@ class MarGenerationSpec: # pylint: disable=R0903
}
class VisualizationSpec:
class VisualizationSpec: #pylint: disable=too-few-public-methods
"""Visualization Specification class.
For validating the parameter 'type'
"""
@ -142,7 +141,7 @@ class VisualizationSpec:
}
class MinIoSpec:
class MinIoSpec: #pylint: disable=too-few-public-methods
"""MinIO Specification class.
For validating the parameter 'type'
"""

View File

@ -69,7 +69,7 @@ if __name__ == "__main__":
name="pytorch-kfp-components",
version=version,
description="PyTorch Kubeflow Pipeline",
url="https://github.com/kubeflow/pipelines/tree/master/components",
url="https://github.com/kubeflow/pipelines/tree/master/components/PyTorch/pytorch-kfp-components/",
author="The PyTorch Kubeflow Pipeline Components authors",
author_email="pytorch-kfp-components@fb.com",
license="Apache License 2.0",
@ -79,7 +79,8 @@ if __name__ == "__main__":
install_requires=make_required_install_packages(),
dependency_links=make_dependency_links(),
keywords=[
"Kubeflow",
"Kubeflow Pipelines",
"KFP",
"ML workflow",
"PyTorch",
],

View File

@ -102,7 +102,7 @@ def test_unreachable_endpoint(minio_inputs):
"""Testing unreachable minio endpoint with invalid minio creds."""
os.environ["MINIO_ACCESS_KEY"] = "dummy"
os.environ["MINIO_SECRET_KEY"] = "dummy"
with pytest.raises(Exception, match="Max retries exceeded with url*"):
with pytest.raises(Exception, match="Max retries exceeded with url: "):
upload_to_minio(minio_inputs)