# !/usr/bin/env/python3 # Copyright (c) Facebook, Inc. and its affiliates. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """AG news Classification script.""" import os from argparse import ArgumentParser from pathlib import Path from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, ) from pytorch_kfp_components.components.visualization.component import Visualization from pytorch_kfp_components.components.trainer.component import Trainer from pytorch_kfp_components.components.mar.component import MarGeneration from pytorch_kfp_components.components.utils.argument_parsing import parse_input_args # Argument parser for user defined paths import pytorch_lightning print("Using Pytorch Lighting: {}".format(pytorch_lightning.__version__)) #pylint: disable=no-member parser = ArgumentParser() parser.add_argument( "--dataset_path", type=str, default="output/processing", help="Path to input dataset", ) parser.add_argument( "--mlpipeline_ui_metadata", type=str, default="mlpipeline-ui-metadata.json", help="Path to write mlpipeline-ui-metadata.json", ) parser.add_argument( "--mlpipeline_metrics", type=str, default="mlpipeline-metrics", help="Path to write mlpipeline-metrics.json", ) parser.add_argument( "--script_args", type=str, help="Arguments for bert agnews classification script", ) parser.add_argument( "--ptl_args", type=str, help="Arguments specific to PTL trainer", ) parser.add_argument( "--checkpoint_dir", default="output/train/models", type=str, help="Arguments specific to PTL trainer", ) parser.add_argument( "--tensorboard_root", default="output/tensorboard", type=str, help="Arguments specific to PTL trainer", ) args = vars(parser.parse_args()) script_args = args["script_args"] ptl_args = args["ptl_args"] TENSOBOARD_ROOT = args["tensorboard_root"] CHECKPOINT_DIR = args["checkpoint_dir"] DATASET_PATH = args["dataset_path"] script_dict: dict = parse_input_args(input_str=script_args) script_dict["checkpoint_dir"] = CHECKPOINT_DIR ptl_dict: dict = parse_input_args(input_str=ptl_args) # Enabling Tensorboard Logger, ModelCheckpoint, Earlystopping lr_logger = LearningRateMonitor() tboard = TensorBoardLogger(TENSOBOARD_ROOT) early_stopping = EarlyStopping( monitor="val_loss", mode="min", patience=5, verbose=True ) checkpoint_callback = ModelCheckpoint( dirpath=CHECKPOINT_DIR, filename="bert_{epoch:02d}", save_top_k=1, verbose=True, monitor="val_loss", mode="min", ) if "accelerator" in ptl_dict and ptl_dict["accelerator"] == "None": ptl_dict["accelerator"] = None # Setting the trainer specific arguments trainer_args = { "logger": tboard, "enable_checkpointing": False, "callbacks": [lr_logger, early_stopping], } if not ptl_dict["max_epochs"]: trainer_args["max_epochs"] = 1 else: trainer_args["max_epochs"] = ptl_dict["max_epochs"] if "profiler" in ptl_dict and ptl_dict["profiler"] != "": trainer_args["profiler"] = ptl_dict["profiler"] # Setting the datamodule specific arguments data_module_args = { "train_glob": DATASET_PATH, "num_samples": script_dict["num_samples"] } # Creating parent directories Path(TENSOBOARD_ROOT).mkdir(parents=True, exist_ok=True) Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True) # Updating all the input parameter to PTL dict trainer_args.update(ptl_dict) # Initiating the training process trainer = Trainer( module_file="bert_train.py", data_module_file="bert_datamodule.py", module_file_args=script_dict, data_module_args=data_module_args, trainer_args=trainer_args, ) print("Generated tensorboard files") for root, dirs, files in os.walk(args["tensorboard_root"]): # pylint: disable=unused-variable for file in files: print(file) model = trainer.ptl_trainer.lightning_module if trainer.ptl_trainer.global_rank == 0: # Mar file generation bert_dir, _ = os.path.split(os.path.abspath(__file__)) mar_config = { "MODEL_NAME": "bert_test", "MODEL_FILE": os.path.join(bert_dir, "bert_train.py"), "HANDLER": os.path.join(bert_dir, "bert_handler.py"), "SERIALIZED_FILE": os.path.join(CHECKPOINT_DIR, script_dict["model_name"]), "VERSION": "1", "EXPORT_PATH": CHECKPOINT_DIR, "CONFIG_PROPERTIES": os.path.join(bert_dir, "config.properties"), "EXTRA_FILES": "{},{},{}".format( os.path.join(bert_dir, "bert-base-uncased-vocab.txt"), os.path.join(bert_dir, "index_to_name.json"), os.path.join(bert_dir, "wrapper.py") ), "REQUIREMENTS_FILE": os.path.join(bert_dir, "requirements.txt") } MarGeneration(mar_config=mar_config, mar_save_path=CHECKPOINT_DIR) print("Generated checkpoint files") for root, dirs, files in os.walk(CHECKPOINT_DIR): # pylint: disable=unused-variable for file in files: path = os.path.join(root, file) size = os.stat(path).st_size # in bytes if ".pth" in file: print("Removing file: ", path) os.remove(path) classes = [ "World", "Sports", "Business", "Sci/Tech", ] # model = trainer.ptl_trainer.model target_index_list = list(set(model.target)) class_list = [] for index in target_index_list: class_list.append(classes[index]) confusion_matrix_dict = { "actuals": model.target, "preds": model.preds, "classes": class_list, "url": script_dict["confusion_matrix_url"], } test_accuracy = round(float(model.test_acc.compute()), 2) print("Model test accuracy: ", test_accuracy) visualization_arguments = { "input": { "tensorboard_root": TENSOBOARD_ROOT, "checkpoint_dir": CHECKPOINT_DIR, "dataset_path": DATASET_PATH, "model_name": script_dict["model_name"], "confusion_matrix_url": script_dict["confusion_matrix_url"], }, "output": { "mlpipeline_ui_metadata": args["mlpipeline_ui_metadata"], "mlpipeline_metrics": args["mlpipeline_metrics"], }, } markdown_dict = {"storage": "inline", "source": visualization_arguments} print("Visualization Arguments: ", markdown_dict) visualization = Visualization( test_accuracy=test_accuracy, confusion_matrix_dict=confusion_matrix_dict, mlpipeline_ui_metadata=args["mlpipeline_ui_metadata"], mlpipeline_metrics=args["mlpipeline_metrics"], markdown=markdown_dict, )