feat(components): Addressing Review comments on Trainer component for PyTorch - KFP (#5814)
* fix url and keywords in setup.py Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com> * Minio UT fix Signed-off-by: ankan94 <ankan@ideas2it.com> * Base component metaclass modification Signed-off-by: ankan94 <ankan@ideas2it.com> * Adding fast failing for trainer component. Signed-off-by: ankan94 <ankan@ideas2it.com> * fixing minio test Signed-off-by: ankan94 <ankan@ideas2it.com> * Fixing lint issues Signed-off-by: ankan94 <ankan@ideas2it.com> Co-authored-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>
This commit is contained in:
parent
60ed8e45f7
commit
72f7464d01
|
|
@ -15,11 +15,10 @@
|
|||
"""Pipeline Base component class."""
|
||||
|
||||
import abc
|
||||
from six import with_metaclass
|
||||
from pytorch_kfp_components.types import standard_component_specs
|
||||
|
||||
|
||||
class BaseComponent(with_metaclass(abc.ABCMeta, object)): # pylint: disable=R0903
|
||||
class BaseComponent(metaclass=abc.ABCMeta): # pylint: disable=R0903
|
||||
"""Pipeline Base component class."""
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class Executor(BaseExecutor):
|
|||
|
||||
print("Running Archiver cmd: ", archiver_cmd)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
proc = subprocess.Popen( #pylint: disable=consider-using-with
|
||||
archiver_cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@
|
|||
# limitations under the License.
|
||||
"""Minio Executor Module."""
|
||||
import os
|
||||
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
|
||||
from pytorch_kfp_components.types import standard_component_specs
|
||||
import urllib3
|
||||
from minio import Minio #pylint: disable=no-name-in-module
|
||||
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
|
||||
from pytorch_kfp_components.types import standard_component_specs
|
||||
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from pytorch_kfp_components.components.base.base_component import BaseComponent
|
|||
from pytorch_kfp_components.types import standard_component_specs
|
||||
|
||||
|
||||
class Trainer(BaseComponent):
|
||||
class Trainer(BaseComponent): #pylint: disable=too-few-public-methods
|
||||
"""Initializes the Trainer class."""
|
||||
|
||||
def __init__( # pylint: disable=R0913
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class Executor(GenericExecutor):
|
|||
def __init__(self): # pylint:disable=useless-super-delegation
|
||||
super().__init__()
|
||||
|
||||
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
|
||||
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict): #pylint: disable=too-many-locals
|
||||
"""This function of the Executor invokes the PyTorch Lightning training
|
||||
loop.
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ class Executor(GenericExecutor):
|
|||
trainer_args,
|
||||
module_file_args,
|
||||
data_module_args,
|
||||
) = self._GetFnArgs(
|
||||
) = self._get_fn_args(
|
||||
input_dict=input_dict,
|
||||
output_dict=output_dict,
|
||||
execution_properties=exec_properties,
|
||||
|
|
@ -75,6 +75,11 @@ class Executor(GenericExecutor):
|
|||
) = self.derive_model_and_data_module_class(
|
||||
module_file=module_file, data_module_file=data_module_file
|
||||
)
|
||||
if not data_module_class :
|
||||
raise NotImplementedError(
|
||||
"Data module class is mandatory. "
|
||||
"User defined training module is yet to be supported."
|
||||
)
|
||||
if data_module_class:
|
||||
data_module = data_module_class(
|
||||
**data_module_args if data_module_args else {}
|
||||
|
|
@ -93,8 +98,8 @@ class Executor(GenericExecutor):
|
|||
parser = Namespace(**module_file_args)
|
||||
trainer = pl.Trainer.from_argparse_args(parser)
|
||||
|
||||
trainer.fit(model, data_module)
|
||||
trainer.test()
|
||||
trainer.fit(model, data_module) #pylint: disable=no-member
|
||||
trainer.test() #pylint: disable=no-member
|
||||
|
||||
if "checkpoint_dir" in module_file_args:
|
||||
model_save_path = module_file_args["checkpoint_dir"]
|
||||
|
|
@ -114,9 +119,3 @@ class Executor(GenericExecutor):
|
|||
output_dict[standard_component_specs.TRAINER_MODEL_SAVE_PATH
|
||||
] = model_save_path
|
||||
output_dict[standard_component_specs.PTL_TRAINER_OBJ] = trainer
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Data module class is mandatory. "
|
||||
"User defined training module is yet to be supported."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,11 +24,11 @@ class GenericExecutor(BaseExecutor):
|
|||
"""Generic Executor Class that does nothing."""
|
||||
|
||||
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
|
||||
# TODO: Code to train pretrained model
|
||||
#TODO: Code to train pretrained model #pylint: disable=fixme
|
||||
pass
|
||||
|
||||
def _GetFnArgs(
|
||||
self, input_dict: dict, output_dict: dict, execution_properties: dict
|
||||
def _get_fn_args( #pylint: disable=no-self-use
|
||||
self, input_dict: dict, output_dict: dict, execution_properties: dict #pylint: disable=unused-argument
|
||||
):
|
||||
"""Gets the input/output/execution properties from the dictionary.
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ class GenericExecutor(BaseExecutor):
|
|||
data_module_args,
|
||||
)
|
||||
|
||||
def derive_model_and_data_module_class(
|
||||
def derive_model_and_data_module_class( #pylint: disable=no-self-use
|
||||
self, module_file: str, data_module_file: str
|
||||
):
|
||||
"""Derives the model file and data modul file.
|
||||
|
|
|
|||
|
|
@ -12,10 +12,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module for defining standard specifications and validation of parameter
|
||||
type."""
|
||||
|
||||
#pylint: disable=duplicate-code
|
||||
TRAINER_MODULE_FILE = "module_file"
|
||||
TRAINER_DATA_MODULE_FILE = "data_module_file"
|
||||
TRAINER_DATA_MODULE_ARGS = "data_module_args"
|
||||
|
|
@ -112,7 +111,7 @@ class MarGenerationSpec: # pylint: disable=R0903
|
|||
}
|
||||
|
||||
|
||||
class VisualizationSpec:
|
||||
class VisualizationSpec: #pylint: disable=too-few-public-methods
|
||||
"""Visualization Specification class.
|
||||
For validating the parameter 'type'
|
||||
"""
|
||||
|
|
@ -142,7 +141,7 @@ class VisualizationSpec:
|
|||
}
|
||||
|
||||
|
||||
class MinIoSpec:
|
||||
class MinIoSpec: #pylint: disable=too-few-public-methods
|
||||
"""MinIO Specification class.
|
||||
For validating the parameter 'type'
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ if __name__ == "__main__":
|
|||
name="pytorch-kfp-components",
|
||||
version=version,
|
||||
description="PyTorch Kubeflow Pipeline",
|
||||
url="https://github.com/kubeflow/pipelines/tree/master/components",
|
||||
url="https://github.com/kubeflow/pipelines/tree/master/components/PyTorch/pytorch-kfp-components/",
|
||||
author="The PyTorch Kubeflow Pipeline Components authors",
|
||||
author_email="pytorch-kfp-components@fb.com",
|
||||
license="Apache License 2.0",
|
||||
|
|
@ -79,7 +79,8 @@ if __name__ == "__main__":
|
|||
install_requires=make_required_install_packages(),
|
||||
dependency_links=make_dependency_links(),
|
||||
keywords=[
|
||||
"Kubeflow",
|
||||
"Kubeflow Pipelines",
|
||||
"KFP",
|
||||
"ML workflow",
|
||||
"PyTorch",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ def test_unreachable_endpoint(minio_inputs):
|
|||
"""Testing unreachable minio endpoint with invalid minio creds."""
|
||||
os.environ["MINIO_ACCESS_KEY"] = "dummy"
|
||||
os.environ["MINIO_SECRET_KEY"] = "dummy"
|
||||
with pytest.raises(Exception, match="Max retries exceeded with url*"):
|
||||
with pytest.raises(Exception, match="Max retries exceeded with url: "):
|
||||
upload_to_minio(minio_inputs)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue