fix(components/pytorch) Pytorch - Tensorboard Profiler fix (#5860)

* Fix: module_file_args overriding the trainer_args variable. Updating module_file_args as the superset

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Setting PTL to 1.3.5 in requirements.txt

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing typo

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding print statements for profiler debugging

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing cpuonly tag

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
This commit is contained in:
shrinath-suresh 2021-06-17 04:36:27 +05:30 committed by GitHub
parent 7607841f6a
commit 9cfa4dfc0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 3 deletions

View File

@ -89,8 +89,8 @@ class Executor(GenericExecutor):
if not isinstance(trainer_args, dict):
raise TypeError("trainer_args must be a dict")
trainer_args.update(module_file_args)
parser = Namespace(**trainer_args)
module_file_args.update(trainer_args)
parser = Namespace(**module_file_args)
trainer = pl.Trainer.from_argparse_args(parser)
trainer.fit(model, data_module)

View File

@ -47,7 +47,7 @@ RUN curl -sLo ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-late
&& conda install pip \
&& conda clean -ya
RUN conda install pytorch cpuonly -c pytorch \
RUN conda install pytorch -c pytorch \
&& conda clean -ya
WORKDIR /home/user/

View File

@ -26,6 +26,8 @@ from pytorch_kfp_components.components.visualization.component import Visualizat
from pytorch_kfp_components.components.trainer.component import Trainer
from pytorch_kfp_components.components.mar.component import MarGeneration
# Argument parser for user defined paths
import pytorch_lightning
print("Using Pytorch Lighting: {}".format(pytorch_lightning.__version__))
parser = ArgumentParser()
parser.add_argument(
@ -136,6 +138,11 @@ trainer = Trainer(
trainer_args=trainer_args,
)
print("Generated tensorboard files")
for root, dirs, files in os.walk(args["tensorboard_root"]): # pylint: disable=unused-variable
for file in files:
print(file)
model = trainer.ptl_trainer.get_model()
if trainer.ptl_trainer.global_rank == 0:
@ -219,3 +226,4 @@ if trainer.ptl_trainer.global_rank == 0:
mlpipeline_metrics=args["mlpipeline_metrics"],
markdown=markdown_dict,
)