feat(components): Set display names for SFT, RLHF and LLM inference pipelines

PiperOrigin-RevId: 572897105
This commit is contained in:
Googler 2023-10-12 07:16:54 -07:00 committed by Google Cloud Pipeline Components maintainers
parent 412216f832
commit 1386a826ba
3 changed files with 42 additions and 35 deletions

View File

@@ -2,6 +2,7 @@
* Upload tensorboard metrics from `preview.llm.rlhf_pipeline` if a `tensorboard_resource_id` is provided at runtime.
* Support `incremental_train_base_model`, `parent_model`, `is_default_version`, `model_version_aliases`, `model_version_description` in `AutoMLImageTrainingJobRunOp`.
* Add `preview.automl.vision` and `DataConverterJobOp`.
* Set display names for `preview.llm` pipelines.

## Release 2.4.1
* Disable caching for LLM pipeline tasks that store temporary artifacts.

View File

@@ -64,14 +64,14 @@ def infer_pipeline(
machine_spec = function_based.resolve_machine_spec(
location=location,
use_test_spec=env.get_use_test_machine_spec(),
)
).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference
).set_display_name('BaseModelMetadataResolver')
).set_display_name('Resolve Model Metadata')
prompt_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_importer',
).set_display_name('PromptDatasetImageUriResolver')
).set_display_name('Resolve Prompt Dataset Image URI')
prompt_dataset_importer = (
private_text_importer.PrivateTextImporter(
project=project,
@@ -86,7 +86,7 @@ def infer_pipeline(
image_uri=prompt_dataset_image_uri.output,
instruction=instruction,
)
.set_display_name('PromptDatasetImporter')
.set_display_name('Import Prompt Dataset')
.set_caching_options(False)
)
@@ -94,7 +94,7 @@ def infer_pipeline(
image_name='infer',
accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('BulkInferrerImageUriResolver')
).set_display_name('Resolve Bulk Inferrer Image URI')
bulk_inference = bulk_inferrer.BulkInferrer(
project=project,
location=location,

View File

@@ -93,15 +93,15 @@ def rlhf_pipeline(
upload_location = 'us-central1'
machine_spec = function_based.resolve_machine_spec(
location=location, use_test_spec=env.get_use_test_machine_spec()
)
).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference,
).set_display_name('BaseModelMetadataResolver')
).set_display_name('Resolve Model Metadata')
prompt_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_importer'
).set_display_name('PromptDatasetImageUriResolver')
).set_display_name('Resolve Prompt Dataset Image URI')
prompt_dataset_importer = (
private_text_importer.PrivateTextImporter(
project=project,
@@ -117,13 +117,13 @@ def rlhf_pipeline(
image_uri=prompt_dataset_image_uri.output,
instruction=instruction,
)
.set_display_name('PromptDatasetImporter')
.set_display_name('Import Prompt Dataset')
.set_caching_options(False)
)
preference_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_comparison_importer'
).set_display_name('PreferenceDatasetImageUriResolver')
).set_display_name('Resolve Preference Dataset Image URI')
comma_separated_candidates_field_names = (
function_based.convert_to_delimited_string(items=candidate_columns)
)
@@ -142,7 +142,7 @@ def rlhf_pipeline(
image_uri=preference_dataset_image_uri.output,
instruction=instruction,
)
.set_display_name('PreferenceDatasetImporter')
.set_display_name('Import Preference Dataset')
.set_caching_options(False)
)
@@ -150,7 +150,7 @@ def rlhf_pipeline(
image_name='reward_model',
accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('RewardModelImageUriResolver')
).set_display_name('Resolve Reward Model Image URI')
reward_model = (
reward_model_trainer.RewardModelTrainer(
project=project,
@@ -175,13 +175,13 @@ def rlhf_pipeline(
learning_rate_multiplier=reward_model_learning_rate_multiplier,
lora_dim=reward_model_lora_dim,
)
.set_display_name('RewardModelTrainer')
.set_display_name('Reward Model Trainer')
.set_caching_options(False)
)
has_tensorboard_id = function_based.value_exists(
value=tensorboard_resource_id
)
).set_display_name('Resolve Tensorboard Resource ID')
with kfp.dsl.Condition( # pytype: disable=wrong-arg-types
has_tensorboard_id.output == True, # pylint: disable=singleton-comparison, g-explicit-bool-comparison
name='Upload Reward Model Tensorboard Metrics',
@@ -194,13 +194,13 @@ def rlhf_pipeline(
f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
),
)
).set_display_name('Reward Model Tensorboard Metrics Uploader')
rl_image_uri = function_based.resolve_private_image_uri(
image_name='reinforcer',
accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('ReinforcerImageUriResolver')
).set_display_name('Resolve Reinforcer Image URI')
rl_model = (
reinforcer.Reinforcer(
project=project,
@@ -246,9 +246,11 @@ def rlhf_pipeline(
f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
),
)
).set_display_name('Reinforcement Learning Tensorboard Metrics Uploader')
should_perform_inference = function_based.value_exists(value=eval_dataset)
should_perform_inference = function_based.value_exists(
value=eval_dataset
).set_display_name('Resolve Inference Dataset')
with kfp.dsl.Condition(
should_perform_inference.output == True, name='Perform Inference' # pylint: disable=singleton-comparison
):
@@ -266,39 +268,43 @@ def rlhf_pipeline(
adapter_artifact = kfp.dsl.importer(
artifact_uri=rl_model.outputs['output_adapter_path'],
artifact_class=kfp.dsl.Artifact,
)
).set_display_name('Import Tuned Adapter')
regional_endpoint = function_based.resolve_regional_endpoint(
upload_location=upload_location
)
).set_display_name('Resolve Regional Endpoint')
display_name = function_based.resolve_model_display_name(
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
],
model_display_name=model_display_name,
)
).set_display_name('Resolve Model Display Name')
upload_model = function_based.resolve_upload_model(
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
]
)
upload_task = upload_llm_model.upload_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location,
regional_endpoint=regional_endpoint.output,
artifact_uri=adapter_artifact.output,
model_display_name=display_name.output,
model_reference_name='text-bison@001',
upload_model=upload_model.output,
).set_env_variable(
name='VERTEX_AI_PIPELINES_RUN_LABELS',
value=json.dumps({'tune-type': 'rlhf'}),
).set_display_name('Resolve Upload Model')
upload_task = (
upload_llm_model.upload_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location,
regional_endpoint=regional_endpoint.output,
artifact_uri=adapter_artifact.output,
model_display_name=display_name.output,
model_reference_name='text-bison@001',
upload_model=upload_model.output,
)
.set_env_variable(
name='VERTEX_AI_PIPELINES_RUN_LABELS',
value=json.dumps({'tune-type': 'rlhf'}),
)
.set_display_name('Upload Model')
)
deploy_model = function_based.resolve_deploy_model(
deploy_model=deploy_model,
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
],
)
).set_display_name('Resolve Deploy Model')
deploy_task = deploy_llm_model.create_endpoint_and_deploy_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location,
@@ -306,7 +312,7 @@ def rlhf_pipeline(
display_name=display_name.output,
regional_endpoint=regional_endpoint.output,
deploy_model=deploy_model.output,
)
).set_display_name('Deploy Model')
return PipelineOutput(
model_resource_name=upload_task.outputs['model_resource_name'],