From 755c1f9898b3c1e1c539403d43e27a3ea3994447 Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 27 Feb 2024 16:53:03 -0800 Subject: [PATCH] fix(components): Pass tuned model checkpoint to inference pipeline after RLHF tuning PiperOrigin-RevId: 610918020 --- components/google-cloud/RELEASE.md | 1 + .../preview/llm/rlhf/component.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/components/google-cloud/RELEASE.md b/components/google-cloud/RELEASE.md index 63561ac05f..8af6583a90 100644 --- a/components/google-cloud/RELEASE.md +++ b/components/google-cloud/RELEASE.md @@ -1,5 +1,6 @@ ## Upcoming release * Add `v1.automl.forecasting.learn_to_learn_forecasting_pipeline`, `v1.automl.forecasting.sequence_to_sequence_forecasting_pipeline`, `v1.automl.forecasting.temporal_fusion_transformer_forecasting_pipeline`, `v1.automl.forecasting.time_series_dense_encoder_forecasting_pipeline` as Forecasting on Pipelines moves to GA. +* Fix bug in `preview.llm.rlhf_pipeline` that caused the wrong output artifact to be used for inference after training. ## Release 2.10.0 * Fix the missing output of pipeline remote runner. `AutoMLImageTrainingJobRunOp` now passes the model artifacts correctly to downstream components. 
diff --git a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py index b089673674..4e5eddd44f 100644 --- a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py @@ -152,7 +152,7 @@ def rlhf_pipeline( name='Perform Inference', ): has_model_checkpoint = function_based.value_exists( - value=rl_model_pipeline.outputs['output_adapter_path'] + value=rl_model_pipeline.outputs['output_model_path'] ).set_display_name('Resolve Model Checkpoint') with kfp.dsl.Condition( has_model_checkpoint.output == True, # pylint: disable=singleton-comparison @@ -162,7 +162,7 @@ def rlhf_pipeline( project=project, location=location, large_model_reference=large_model_reference, - model_checkpoint=rl_model_pipeline.outputs['output_adapter_path'], + model_checkpoint=rl_model_pipeline.outputs['output_model_path'], prompt_dataset=eval_dataset, prompt_sequence_length=prompt_sequence_length, target_sequence_length=target_sequence_length,