fix(components/pytorch) Pytorch - Tensorboard Profiler fix (#5860)
* Fix: module_file_args overriding the trainer_args variable. Updating module_file_args as the superset Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Setting PTL to 1.3.5 in requirements.txt Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Fixing typo Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Adding print statements for profiler debugging Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Removing cpuonly tag Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
This commit is contained in:
parent
7607841f6a
commit
9cfa4dfc0a
|
@ -89,8 +89,8 @@ class Executor(GenericExecutor):
|
|||
if not isinstance(trainer_args, dict):
|
||||
raise TypeError("trainer_args must be a dict")
|
||||
|
||||
trainer_args.update(module_file_args)
|
||||
parser = Namespace(**trainer_args)
|
||||
module_file_args.update(trainer_args)
|
||||
parser = Namespace(**module_file_args)
|
||||
trainer = pl.Trainer.from_argparse_args(parser)
|
||||
|
||||
trainer.fit(model, data_module)
|
||||
|
|
|
@ -47,7 +47,7 @@ RUN curl -sLo ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-late
|
|||
&& conda install pip \
|
||||
&& conda clean -ya
|
||||
|
||||
RUN conda install pytorch cpuonly -c pytorch \
|
||||
RUN conda install pytorch -c pytorch \
|
||||
&& conda clean -ya
|
||||
|
||||
WORKDIR /home/user/
|
||||
|
|
|
@ -26,6 +26,8 @@ from pytorch_kfp_components.components.visualization.component import Visualizat
|
|||
from pytorch_kfp_components.components.trainer.component import Trainer
|
||||
from pytorch_kfp_components.components.mar.component import MarGeneration
|
||||
# Argument parser for user defined paths
|
||||
import pytorch_lightning
|
||||
print("Using Pytorch Lighting: {}".format(pytorch_lightning.__version__))
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -136,6 +138,11 @@ trainer = Trainer(
|
|||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
print("Generated tensorboard files")
|
||||
for root, dirs, files in os.walk(args["tensorboard_root"]): # pylint: disable=unused-variable
|
||||
for file in files:
|
||||
print(file)
|
||||
|
||||
model = trainer.ptl_trainer.get_model()
|
||||
|
||||
if trainer.ptl_trainer.global_rank == 0:
|
||||
|
@ -219,3 +226,4 @@ if trainer.ptl_trainer.global_rank == 0:
|
|||
mlpipeline_metrics=args["mlpipeline_metrics"],
|
||||
markdown=markdown_dict,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue