feat(components): Set display names for SFT, RLHF and LLM inference pipelines

PiperOrigin-RevId: 572897105
This commit is contained in:
Googler 2023-10-12 07:16:54 -07:00 committed by Google Cloud Pipeline Components maintainers
parent 412216f832
commit 1386a826ba
3 changed files with 42 additions and 35 deletions

View File

@@ -2,6 +2,7 @@
* Upload tensorboard metrics from `preview.llm.rlhf_pipeline` if a `tensorboard_resource_id` is provided at runtime. * Upload tensorboard metrics from `preview.llm.rlhf_pipeline` if a `tensorboard_resource_id` is provided at runtime.
* Support `incremental_train_base_model`, `parent_model`, `is_default_version`, `model_version_aliases`, `model_version_description` in `AutoMLImageTrainingJobRunOp`. * Support `incremental_train_base_model`, `parent_model`, `is_default_version`, `model_version_aliases`, `model_version_description` in `AutoMLImageTrainingJobRunOp`.
* Add `preview.automl.vision` and `DataConverterJobOp`. * Add `preview.automl.vision` and `DataConverterJobOp`.
* Set display names for `preview.llm` pipelines.
## Release 2.4.1 ## Release 2.4.1
* Disable caching for LLM pipeline tasks that store temporary artifacts. * Disable caching for LLM pipeline tasks that store temporary artifacts.

View File

@@ -64,14 +64,14 @@ def infer_pipeline(
machine_spec = function_based.resolve_machine_spec( machine_spec = function_based.resolve_machine_spec(
location=location, location=location,
use_test_spec=env.get_use_test_machine_spec(), use_test_spec=env.get_use_test_machine_spec(),
) ).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata( reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference large_model_reference=large_model_reference
).set_display_name('BaseModelMetadataResolver') ).set_display_name('Resolve Model Metadata')
prompt_dataset_image_uri = function_based.resolve_private_image_uri( prompt_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_importer', image_name='text_importer',
).set_display_name('PromptDatasetImageUriResolver') ).set_display_name('Resolve Prompt Dataset Image URI')
prompt_dataset_importer = ( prompt_dataset_importer = (
private_text_importer.PrivateTextImporter( private_text_importer.PrivateTextImporter(
project=project, project=project,
@@ -86,7 +86,7 @@ def infer_pipeline(
image_uri=prompt_dataset_image_uri.output, image_uri=prompt_dataset_image_uri.output,
instruction=instruction, instruction=instruction,
) )
.set_display_name('PromptDatasetImporter') .set_display_name('Import Prompt Dataset')
.set_caching_options(False) .set_caching_options(False)
) )
@@ -94,7 +94,7 @@ def infer_pipeline(
image_name='infer', image_name='infer',
accelerator_type=machine_spec.outputs['accelerator_type'], accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'], accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('BulkInferrerImageUriResolver') ).set_display_name('Resolve Bulk Inferrer Image URI')
bulk_inference = bulk_inferrer.BulkInferrer( bulk_inference = bulk_inferrer.BulkInferrer(
project=project, project=project,
location=location, location=location,

View File

@@ -93,15 +93,15 @@ def rlhf_pipeline(
upload_location = 'us-central1' upload_location = 'us-central1'
machine_spec = function_based.resolve_machine_spec( machine_spec = function_based.resolve_machine_spec(
location=location, use_test_spec=env.get_use_test_machine_spec() location=location, use_test_spec=env.get_use_test_machine_spec()
) ).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata( reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference, large_model_reference=large_model_reference,
).set_display_name('BaseModelMetadataResolver') ).set_display_name('Resolve Model Metadata')
prompt_dataset_image_uri = function_based.resolve_private_image_uri( prompt_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_importer' image_name='text_importer'
).set_display_name('PromptDatasetImageUriResolver') ).set_display_name('Resolve Prompt Dataset Image URI')
prompt_dataset_importer = ( prompt_dataset_importer = (
private_text_importer.PrivateTextImporter( private_text_importer.PrivateTextImporter(
project=project, project=project,
@@ -117,13 +117,13 @@ def rlhf_pipeline(
image_uri=prompt_dataset_image_uri.output, image_uri=prompt_dataset_image_uri.output,
instruction=instruction, instruction=instruction,
) )
.set_display_name('PromptDatasetImporter') .set_display_name('Import Prompt Dataset')
.set_caching_options(False) .set_caching_options(False)
) )
preference_dataset_image_uri = function_based.resolve_private_image_uri( preference_dataset_image_uri = function_based.resolve_private_image_uri(
image_name='text_comparison_importer' image_name='text_comparison_importer'
).set_display_name('PreferenceDatasetImageUriResolver') ).set_display_name('Resolve Preference Dataset Image URI')
comma_separated_candidates_field_names = ( comma_separated_candidates_field_names = (
function_based.convert_to_delimited_string(items=candidate_columns) function_based.convert_to_delimited_string(items=candidate_columns)
) )
@@ -142,7 +142,7 @@ def rlhf_pipeline(
image_uri=preference_dataset_image_uri.output, image_uri=preference_dataset_image_uri.output,
instruction=instruction, instruction=instruction,
) )
.set_display_name('PreferenceDatasetImporter') .set_display_name('Import Preference Dataset')
.set_caching_options(False) .set_caching_options(False)
) )
@@ -150,7 +150,7 @@ def rlhf_pipeline(
image_name='reward_model', image_name='reward_model',
accelerator_type=machine_spec.outputs['accelerator_type'], accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'], accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('RewardModelImageUriResolver') ).set_display_name('Resolve Reward Model Image URI')
reward_model = ( reward_model = (
reward_model_trainer.RewardModelTrainer( reward_model_trainer.RewardModelTrainer(
project=project, project=project,
@@ -175,13 +175,13 @@ def rlhf_pipeline(
learning_rate_multiplier=reward_model_learning_rate_multiplier, learning_rate_multiplier=reward_model_learning_rate_multiplier,
lora_dim=reward_model_lora_dim, lora_dim=reward_model_lora_dim,
) )
.set_display_name('RewardModelTrainer') .set_display_name('Reward Model Trainer')
.set_caching_options(False) .set_caching_options(False)
) )
has_tensorboard_id = function_based.value_exists( has_tensorboard_id = function_based.value_exists(
value=tensorboard_resource_id value=tensorboard_resource_id
) ).set_display_name('Resolve Tensorboard Resource ID')
with kfp.dsl.Condition( # pytype: disable=wrong-arg-types with kfp.dsl.Condition( # pytype: disable=wrong-arg-types
has_tensorboard_id.output == True, # pylint: disable=singleton-comparison, g-explicit-bool-comparison has_tensorboard_id.output == True, # pylint: disable=singleton-comparison, g-explicit-bool-comparison
name='Upload Reward Model Tensorboard Metrics', name='Upload Reward Model Tensorboard Metrics',
@@ -194,13 +194,13 @@ def rlhf_pipeline(
f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-' f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}' f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
), ),
) ).set_display_name('Reward Model Tensorboard Metrics Uploader')
rl_image_uri = function_based.resolve_private_image_uri( rl_image_uri = function_based.resolve_private_image_uri(
image_name='reinforcer', image_name='reinforcer',
accelerator_type=machine_spec.outputs['accelerator_type'], accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'], accelerator_count=machine_spec.outputs['accelerator_count'],
).set_display_name('ReinforcerImageUriResolver') ).set_display_name('Resolve Reinforcer Image URI')
rl_model = ( rl_model = (
reinforcer.Reinforcer( reinforcer.Reinforcer(
project=project, project=project,
@@ -246,9 +246,11 @@ def rlhf_pipeline(
f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-' f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}' f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
), ),
) ).set_display_name('Reinforcement Learning Tensorboard Metrics Uploader')
should_perform_inference = function_based.value_exists(value=eval_dataset) should_perform_inference = function_based.value_exists(
value=eval_dataset
).set_display_name('Resolve Inference Dataset')
with kfp.dsl.Condition( with kfp.dsl.Condition(
should_perform_inference.output == True, name='Perform Inference' # pylint: disable=singleton-comparison should_perform_inference.output == True, name='Perform Inference' # pylint: disable=singleton-comparison
): ):
@@ -266,39 +268,43 @@ def rlhf_pipeline(
adapter_artifact = kfp.dsl.importer( adapter_artifact = kfp.dsl.importer(
artifact_uri=rl_model.outputs['output_adapter_path'], artifact_uri=rl_model.outputs['output_adapter_path'],
artifact_class=kfp.dsl.Artifact, artifact_class=kfp.dsl.Artifact,
) ).set_display_name('Import Tuned Adapter')
regional_endpoint = function_based.resolve_regional_endpoint( regional_endpoint = function_based.resolve_regional_endpoint(
upload_location=upload_location upload_location=upload_location
) ).set_display_name('Resolve Regional Endpoint')
display_name = function_based.resolve_model_display_name( display_name = function_based.resolve_model_display_name(
large_model_reference=reference_model_metadata.outputs[ large_model_reference=reference_model_metadata.outputs[
'large_model_reference' 'large_model_reference'
], ],
model_display_name=model_display_name, model_display_name=model_display_name,
) ).set_display_name('Resolve Model Display Name')
upload_model = function_based.resolve_upload_model( upload_model = function_based.resolve_upload_model(
large_model_reference=reference_model_metadata.outputs[ large_model_reference=reference_model_metadata.outputs[
'large_model_reference' 'large_model_reference'
] ]
) ).set_display_name('Resolve Upload Model')
upload_task = upload_llm_model.upload_llm_model( upload_task = (
project=_placeholders.PROJECT_ID_PLACEHOLDER, upload_llm_model.upload_llm_model(
location=upload_location, project=_placeholders.PROJECT_ID_PLACEHOLDER,
regional_endpoint=regional_endpoint.output, location=upload_location,
artifact_uri=adapter_artifact.output, regional_endpoint=regional_endpoint.output,
model_display_name=display_name.output, artifact_uri=adapter_artifact.output,
model_reference_name='text-bison@001', model_display_name=display_name.output,
upload_model=upload_model.output, model_reference_name='text-bison@001',
).set_env_variable( upload_model=upload_model.output,
name='VERTEX_AI_PIPELINES_RUN_LABELS', )
value=json.dumps({'tune-type': 'rlhf'}), .set_env_variable(
name='VERTEX_AI_PIPELINES_RUN_LABELS',
value=json.dumps({'tune-type': 'rlhf'}),
)
.set_display_name('Upload Model')
) )
deploy_model = function_based.resolve_deploy_model( deploy_model = function_based.resolve_deploy_model(
deploy_model=deploy_model, deploy_model=deploy_model,
large_model_reference=reference_model_metadata.outputs[ large_model_reference=reference_model_metadata.outputs[
'large_model_reference' 'large_model_reference'
], ],
) ).set_display_name('Resolve Deploy Model')
deploy_task = deploy_llm_model.create_endpoint_and_deploy_model( deploy_task = deploy_llm_model.create_endpoint_and_deploy_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER, project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location, location=upload_location,
@@ -306,7 +312,7 @@ def rlhf_pipeline(
display_name=display_name.output, display_name=display_name.output,
regional_endpoint=regional_endpoint.output, regional_endpoint=regional_endpoint.output,
deploy_model=deploy_model.output, deploy_model=deploy_model.output,
) ).set_display_name('Deploy Model')
return PipelineOutput( return PipelineOutput(
model_resource_name=upload_task.outputs['model_resource_name'], model_resource_name=upload_task.outputs['model_resource_name'],