feat(components): Set display names for SFT, RLHF and LLM inference pipelines
PiperOrigin-RevId: 572897105
parent 412216f832
commit 1386a826ba
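The change itself is mechanical: each task in the `preview.llm` SFT, RLHF, and inference pipelines now gets a human-readable display name via KFP's `PipelineTask.set_display_name`, so the Vertex AI Pipelines UI shows labels like "Resolve Machine Spec" instead of auto-generated component names. A minimal sketch of the pattern, assuming the KFP v2 SDK and a hypothetical component (illustrative only, not code from this commit):

```python
from kfp import dsl


@dsl.component
def resolve_machine_spec(location: str) -> str:
  """Hypothetical stand-in for the pipelines' machine-spec resolver."""
  return 'a2-ultragpu-8g' if location == 'us-central1' else 'n1-standard-8'


@dsl.pipeline(name='display-name-demo')
def demo_pipeline(location: str = 'us-central1'):
  # set_display_name renames the task in the pipeline UI; the chained call
  # returns the task, so its output can still be consumed downstream.
  machine_spec = resolve_machine_spec(location=location).set_display_name(
      'Resolve Machine Spec'
  )
  # Tasks that produce temporary artifacts also disable caching, matching the
  # hunks below that pair set_display_name with set_caching_options(False).
  machine_spec.set_caching_options(False)
```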
@@ -2,6 +2,7 @@
 * Upload tensorboard metrics from `preview.llm.rlhf_pipeline` if a `tensorboard_resource_id` is provided at runtime.
 * Support `incremental_train_base_model`, `parent_model`, `is_default_version`, `model_version_aliases`, `model_version_description` in `AutoMLImageTrainingJobRunOp`.
 * Add `preview.automl.vision` and `DataConverterJobOp`.
+* Set display names for `preview.llm` pipelines.
 
 ## Release 2.4.1
 * Disable caching for LLM pipeline tasks that store temporary artifacts.

@@ -64,14 +64,14 @@ def infer_pipeline(
   machine_spec = function_based.resolve_machine_spec(
       location=location,
       use_test_spec=env.get_use_test_machine_spec(),
-  )
+  ).set_display_name('Resolve Machine Spec')
   reference_model_metadata = function_based.resolve_reference_model_metadata(
       large_model_reference=large_model_reference
-  ).set_display_name('BaseModelMetadataResolver')
+  ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_importer',
-  ).set_display_name('PromptDatasetImageUriResolver')
+  ).set_display_name('Resolve Prompt Dataset Image URI')
   prompt_dataset_importer = (
       private_text_importer.PrivateTextImporter(
           project=project,

@@ -86,7 +86,7 @@ def infer_pipeline(
           image_uri=prompt_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PromptDatasetImporter')
+      .set_display_name('Import Prompt Dataset')
       .set_caching_options(False)
   )
 

@@ -94,7 +94,7 @@ def infer_pipeline(
       image_name='infer',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('BulkInferrerImageUriResolver')
+  ).set_display_name('Resolve Bulk Inferrer Image URI')
   bulk_inference = bulk_inferrer.BulkInferrer(
       project=project,
       location=location,

@@ -93,15 +93,15 @@ def rlhf_pipeline(
   upload_location = 'us-central1'
   machine_spec = function_based.resolve_machine_spec(
       location=location, use_test_spec=env.get_use_test_machine_spec()
-  )
+  ).set_display_name('Resolve Machine Spec')
 
   reference_model_metadata = function_based.resolve_reference_model_metadata(
       large_model_reference=large_model_reference,
-  ).set_display_name('BaseModelMetadataResolver')
+  ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_importer'
-  ).set_display_name('PromptDatasetImageUriResolver')
+  ).set_display_name('Resolve Prompt Dataset Image URI')
   prompt_dataset_importer = (
       private_text_importer.PrivateTextImporter(
           project=project,

@@ -117,13 +117,13 @@ def rlhf_pipeline(
           image_uri=prompt_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PromptDatasetImporter')
+      .set_display_name('Import Prompt Dataset')
       .set_caching_options(False)
   )
 
   preference_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_comparison_importer'
-  ).set_display_name('PreferenceDatasetImageUriResolver')
+  ).set_display_name('Resolve Preference Dataset Image URI')
   comma_separated_candidates_field_names = (
       function_based.convert_to_delimited_string(items=candidate_columns)
   )

@@ -142,7 +142,7 @@ def rlhf_pipeline(
           image_uri=preference_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PreferenceDatasetImporter')
+      .set_display_name('Import Preference Dataset')
       .set_caching_options(False)
   )
 

@@ -150,7 +150,7 @@ def rlhf_pipeline(
       image_name='reward_model',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('RewardModelImageUriResolver')
+  ).set_display_name('Resolve Reward Model Image URI')
   reward_model = (
       reward_model_trainer.RewardModelTrainer(
           project=project,

@@ -175,13 +175,13 @@ def rlhf_pipeline(
           learning_rate_multiplier=reward_model_learning_rate_multiplier,
           lora_dim=reward_model_lora_dim,
       )
-      .set_display_name('RewardModelTrainer')
+      .set_display_name('Reward Model Trainer')
       .set_caching_options(False)
   )
 
   has_tensorboard_id = function_based.value_exists(
       value=tensorboard_resource_id
-  )
+  ).set_display_name('Resolve Tensorboard Resource ID')
   with kfp.dsl.Condition(  # pytype: disable=wrong-arg-types
       has_tensorboard_id.output == True,  # pylint: disable=singleton-comparison, g-explicit-bool-comparison
       name='Upload Reward Model Tensorboard Metrics',

@@ -194,13 +194,13 @@ def rlhf_pipeline(
             f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
             f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
         ),
-    )
+    ).set_display_name('Reward Model Tensorboard Metrics Uploader')
 
   rl_image_uri = function_based.resolve_private_image_uri(
       image_name='reinforcer',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('ReinforcerImageUriResolver')
+  ).set_display_name('Resolve Reinforcer Image URI')
   rl_model = (
       reinforcer.Reinforcer(
           project=project,

@@ -246,9 +246,11 @@ def rlhf_pipeline(
             f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
             f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
         ),
-    )
+    ).set_display_name('Reinforcement Learning Tensorboard Metrics Uploader')
 
-  should_perform_inference = function_based.value_exists(value=eval_dataset)
+  should_perform_inference = function_based.value_exists(
+      value=eval_dataset
+  ).set_display_name('Resolve Inference Dataset')
   with kfp.dsl.Condition(
       should_perform_inference.output == True, name='Perform Inference'  # pylint: disable=singleton-comparison
   ):

@@ -266,39 +268,43 @@ def rlhf_pipeline(
     adapter_artifact = kfp.dsl.importer(
         artifact_uri=rl_model.outputs['output_adapter_path'],
         artifact_class=kfp.dsl.Artifact,
-    )
+    ).set_display_name('Import Tuned Adapter')
     regional_endpoint = function_based.resolve_regional_endpoint(
         upload_location=upload_location
-    )
+    ).set_display_name('Resolve Regional Endpoint')
     display_name = function_based.resolve_model_display_name(
         large_model_reference=reference_model_metadata.outputs[
             'large_model_reference'
         ],
         model_display_name=model_display_name,
-    )
+    ).set_display_name('Resolve Model Display Name')
     upload_model = function_based.resolve_upload_model(
         large_model_reference=reference_model_metadata.outputs[
             'large_model_reference'
         ]
-    )
-    upload_task = upload_llm_model.upload_llm_model(
-        project=_placeholders.PROJECT_ID_PLACEHOLDER,
-        location=upload_location,
-        regional_endpoint=regional_endpoint.output,
-        artifact_uri=adapter_artifact.output,
-        model_display_name=display_name.output,
-        model_reference_name='text-bison@001',
-        upload_model=upload_model.output,
-    ).set_env_variable(
-        name='VERTEX_AI_PIPELINES_RUN_LABELS',
-        value=json.dumps({'tune-type': 'rlhf'}),
-    )
+    ).set_display_name('Resolve Upload Model')
+    upload_task = (
+        upload_llm_model.upload_llm_model(
+            project=_placeholders.PROJECT_ID_PLACEHOLDER,
+            location=upload_location,
+            regional_endpoint=regional_endpoint.output,
+            artifact_uri=adapter_artifact.output,
+            model_display_name=display_name.output,
+            model_reference_name='text-bison@001',
+            upload_model=upload_model.output,
+        )
+        .set_env_variable(
+            name='VERTEX_AI_PIPELINES_RUN_LABELS',
+            value=json.dumps({'tune-type': 'rlhf'}),
+        )
+        .set_display_name('Upload Model')
+    )
     deploy_model = function_based.resolve_deploy_model(
         deploy_model=deploy_model,
         large_model_reference=reference_model_metadata.outputs[
             'large_model_reference'
         ],
-    )
+    ).set_display_name('Resolve Deploy Model')
     deploy_task = deploy_llm_model.create_endpoint_and_deploy_model(
         project=_placeholders.PROJECT_ID_PLACEHOLDER,
         location=upload_location,

@@ -306,7 +312,7 @@ def rlhf_pipeline(
         display_name=display_name.output,
         regional_endpoint=regional_endpoint.output,
         deploy_model=deploy_model.output,
-    )
+    ).set_display_name('Deploy Model')
 
   return PipelineOutput(
       model_resource_name=upload_task.outputs['model_resource_name'],

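For context on how the conditional branches above are gated ('Upload Reward Model Tensorboard Metrics', 'Perform Inference'): a small resolver task returns a boolean, and the dependent steps run inside a named `kfp.dsl.Condition` block. A rough, self-contained sketch of that gating pattern, with hypothetical components standing in for the library's internal `value_exists` resolver and the bulk-inference step:

```python
from kfp import dsl


@dsl.component
def value_exists(value: str) -> bool:
  """Hypothetical resolver: True when an optional input was provided."""
  return bool(value)


@dsl.component
def run_inference(dataset: str) -> str:
  """Hypothetical placeholder for the bulk-inference step."""
  return f'ran inference on {dataset}'


@dsl.pipeline(name='conditional-inference-demo')
def demo_pipeline(eval_dataset: str = ''):
  should_perform_inference = value_exists(
      value=eval_dataset
  ).set_display_name('Resolve Inference Dataset')
  # The Condition's `name` is the label shown for the group in the UI,
  # mirroring name='Perform Inference' in the diff above.
  with dsl.Condition(
      should_perform_inference.output == True,  # pylint: disable=singleton-comparison
      name='Perform Inference',
  ):
    run_inference(dataset=eval_dataset).set_display_name('Perform Inference')
```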