fix(components/pytorch) Pytorch - Tensorboard Profiler fix (#5860)
* Fix: module_file_args overriding the trainer_args variable. Updating module_file_args as the superset Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Setting PTL to 1.3.5 in requirements.txt Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Fixing typo Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Adding print statements for profiler debugging Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com> * Removing cpuonly tag Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
This commit is contained in:
parent
7607841f6a
commit
9cfa4dfc0a
|
@ -89,8 +89,8 @@ class Executor(GenericExecutor):
|
||||||
if not isinstance(trainer_args, dict):
|
if not isinstance(trainer_args, dict):
|
||||||
raise TypeError("trainer_args must be a dict")
|
raise TypeError("trainer_args must be a dict")
|
||||||
|
|
||||||
trainer_args.update(module_file_args)
|
module_file_args.update(trainer_args)
|
||||||
parser = Namespace(**trainer_args)
|
parser = Namespace(**module_file_args)
|
||||||
trainer = pl.Trainer.from_argparse_args(parser)
|
trainer = pl.Trainer.from_argparse_args(parser)
|
||||||
|
|
||||||
trainer.fit(model, data_module)
|
trainer.fit(model, data_module)
|
||||||
|
|
|
@ -47,7 +47,7 @@ RUN curl -sLo ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-late
|
||||||
&& conda install pip \
|
&& conda install pip \
|
||||||
&& conda clean -ya
|
&& conda clean -ya
|
||||||
|
|
||||||
RUN conda install pytorch cpuonly -c pytorch \
|
RUN conda install pytorch -c pytorch \
|
||||||
&& conda clean -ya
|
&& conda clean -ya
|
||||||
|
|
||||||
WORKDIR /home/user/
|
WORKDIR /home/user/
|
||||||
|
|
|
@ -26,6 +26,8 @@ from pytorch_kfp_components.components.visualization.component import Visualizat
|
||||||
from pytorch_kfp_components.components.trainer.component import Trainer
|
from pytorch_kfp_components.components.trainer.component import Trainer
|
||||||
from pytorch_kfp_components.components.mar.component import MarGeneration
|
from pytorch_kfp_components.components.mar.component import MarGeneration
|
||||||
# Argument parser for user defined paths
|
# Argument parser for user defined paths
|
||||||
|
import pytorch_lightning
|
||||||
|
print("Using Pytorch Lighting: {}".format(pytorch_lightning.__version__))
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -136,6 +138,11 @@ trainer = Trainer(
|
||||||
trainer_args=trainer_args,
|
trainer_args=trainer_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("Generated tensorboard files")
|
||||||
|
for root, dirs, files in os.walk(args["tensorboard_root"]): # pylint: disable=unused-variable
|
||||||
|
for file in files:
|
||||||
|
print(file)
|
||||||
|
|
||||||
model = trainer.ptl_trainer.get_model()
|
model = trainer.ptl_trainer.get_model()
|
||||||
|
|
||||||
if trainer.ptl_trainer.global_rank == 0:
|
if trainer.ptl_trainer.global_rank == 0:
|
||||||
|
@ -219,3 +226,4 @@ if trainer.ptl_trainer.global_rank == 0:
|
||||||
mlpipeline_metrics=args["mlpipeline_metrics"],
|
mlpipeline_metrics=args["mlpipeline_metrics"],
|
||||||
markdown=markdown_dict,
|
markdown=markdown_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue