feat(components): Set display names for SFT, RLHF and LLM inference pipelines
PiperOrigin-RevId: 572897105
commit 1386a826ba (parent 412216f832)
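Context for reviewers: without an explicit display name, each task in the Vertex AI Pipelines graph is labelled with its component's function name, which is hard to scan in large pipelines. This change attaches readable names via `PipelineTask.set_display_name()`. A minimal sketch of the pattern, using an illustrative component rather than the pipeline's actual resolvers:

from kfp import dsl


@dsl.component
def resolve_spec(location: str) -> str:
    # Stand-in resolver; the real components live in google_cloud_pipeline_components.
    return 'a2-ultragpu-1g' if location.startswith('us') else 'n1-standard-8'


@dsl.pipeline(name='display-name-demo')
def demo_pipeline(location: str = 'us-central1'):
    # Without set_display_name this task is labelled "resolve-spec" in the UI.
    spec = resolve_spec(location=location).set_display_name('Resolve Machine Spec')
    # Other task-level settings chain the same way on the returned task object.
    spec.set_caching_options(False)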
@@ -2,6 +2,7 @@
 * Upload tensorboard metrics from `preview.llm.rlhf_pipeline` if a `tensorboard_resource_id` is provided at runtime.
 * Support `incremental_train_base_model`, `parent_model`, `is_default_version`, `model_version_aliases`, `model_version_description` in `AutoMLImageTrainingJobRunOp`.
 * Add `preview.automl.vision` and `DataConverterJobOp`.
+* Set display names for `preview.llm` pipelines.
 
 ## Release 2.4.1
 * Disable caching for LLM pipeline tasks that store temporary artifacts.
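On the first release note above: the tensorboard upload tasks only run when `tensorboard_resource_id` is supplied as a pipeline parameter at submission time. A rough sketch of submitting the pipeline with that parameter through the Vertex AI SDK; the dataset URIs and the exact set of other parameters are illustrative, not taken from the pipeline's signature:

from google.cloud import aiplatform
from google_cloud_pipeline_components.preview.llm import rlhf_pipeline
from kfp import compiler

# Compile the preview pipeline to a local pipeline spec.
compiler.Compiler().compile(rlhf_pipeline, package_path='rlhf_pipeline.yaml')

job = aiplatform.PipelineJob(
    display_name='rlhf-tuning',
    template_path='rlhf_pipeline.yaml',
    location='us-central1',
    parameter_values={
        'prompt_dataset': 'gs://my-bucket/prompts/*.jsonl',          # illustrative
        'preference_dataset': 'gs://my-bucket/preferences/*.jsonl',  # illustrative
        'large_model_reference': 'text-bison@001',                   # illustrative
        # Providing this ID is what enables the tensorboard upload tasks.
        'tensorboard_resource_id': '<tensorboard resource id>',
    },
)
job.submit()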
@@ -64,14 +64,14 @@ def infer_pipeline(
   machine_spec = function_based.resolve_machine_spec(
       location=location,
       use_test_spec=env.get_use_test_machine_spec(),
-  )
+  ).set_display_name('Resolve Machine Spec')
   reference_model_metadata = function_based.resolve_reference_model_metadata(
       large_model_reference=large_model_reference
-  ).set_display_name('BaseModelMetadataResolver')
+  ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_importer',
-  ).set_display_name('PromptDatasetImageUriResolver')
+  ).set_display_name('Resolve Prompt Dataset Image URI')
   prompt_dataset_importer = (
       private_text_importer.PrivateTextImporter(
           project=project,
@@ -86,7 +86,7 @@ def infer_pipeline(
           image_uri=prompt_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PromptDatasetImporter')
+      .set_display_name('Import Prompt Dataset')
       .set_caching_options(False)
   )
 
@@ -94,7 +94,7 @@ def infer_pipeline(
       image_name='infer',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('BulkInferrerImageUriResolver')
+  ).set_display_name('Resolve Bulk Inferrer Image URI')
   bulk_inference = bulk_inferrer.BulkInferrer(
       project=project,
       location=location,
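The `machine_spec.outputs['accelerator_type']` and `outputs['accelerator_count']` accesses above use KFP's multi-output (NamedTuple) components. A minimal sketch of that mechanism with placeholder logic, not the actual `resolve_machine_spec` implementation:

from typing import NamedTuple

from kfp import dsl


@dsl.component
def resolve_machine_spec(location: str) -> NamedTuple(
    'MachineSpec', accelerator_type=str, accelerator_count=int
):
    # Placeholder logic; the real resolver maps regions to supported accelerators.
    spec = NamedTuple('MachineSpec', accelerator_type=str, accelerator_count=int)
    return spec(accelerator_type='TPU_V3', accelerator_count=64)


@dsl.component
def describe(accelerator_type: str, accelerator_count: int) -> str:
    return f'{accelerator_count} x {accelerator_type}'


@dsl.pipeline(name='multi-output-demo')
def demo(location: str = 'us-central1'):
    spec = resolve_machine_spec(location=location).set_display_name(
        'Resolve Machine Spec'
    )
    # Each named field is addressed through task.outputs['<field name>'].
    describe(
        accelerator_type=spec.outputs['accelerator_type'],
        accelerator_count=spec.outputs['accelerator_count'],
    ).set_display_name('Describe Machine Spec')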
@@ -93,15 +93,15 @@ def rlhf_pipeline(
   upload_location = 'us-central1'
   machine_spec = function_based.resolve_machine_spec(
       location=location, use_test_spec=env.get_use_test_machine_spec()
-  )
+  ).set_display_name('Resolve Machine Spec')
 
   reference_model_metadata = function_based.resolve_reference_model_metadata(
       large_model_reference=large_model_reference,
-  ).set_display_name('BaseModelMetadataResolver')
+  ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_importer'
-  ).set_display_name('PromptDatasetImageUriResolver')
+  ).set_display_name('Resolve Prompt Dataset Image URI')
   prompt_dataset_importer = (
       private_text_importer.PrivateTextImporter(
           project=project,
@@ -117,13 +117,13 @@ def rlhf_pipeline(
           image_uri=prompt_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PromptDatasetImporter')
+      .set_display_name('Import Prompt Dataset')
       .set_caching_options(False)
   )
 
   preference_dataset_image_uri = function_based.resolve_private_image_uri(
       image_name='text_comparison_importer'
-  ).set_display_name('PreferenceDatasetImageUriResolver')
+  ).set_display_name('Resolve Preference Dataset Image URI')
   comma_separated_candidates_field_names = (
       function_based.convert_to_delimited_string(items=candidate_columns)
   )
@@ -142,7 +142,7 @@ def rlhf_pipeline(
           image_uri=preference_dataset_image_uri.output,
           instruction=instruction,
       )
-      .set_display_name('PreferenceDatasetImporter')
+      .set_display_name('Import Preference Dataset')
       .set_caching_options(False)
   )
 
@@ -150,7 +150,7 @@ def rlhf_pipeline(
       image_name='reward_model',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('RewardModelImageUriResolver')
+  ).set_display_name('Resolve Reward Model Image URI')
   reward_model = (
       reward_model_trainer.RewardModelTrainer(
           project=project,
@@ -175,13 +175,13 @@ def rlhf_pipeline(
           learning_rate_multiplier=reward_model_learning_rate_multiplier,
           lora_dim=reward_model_lora_dim,
       )
-      .set_display_name('RewardModelTrainer')
+      .set_display_name('Reward Model Trainer')
       .set_caching_options(False)
   )
 
   has_tensorboard_id = function_based.value_exists(
       value=tensorboard_resource_id
-  )
+  ).set_display_name('Resolve Tensorboard Resource ID')
   with kfp.dsl.Condition(  # pytype: disable=wrong-arg-types
       has_tensorboard_id.output == True,  # pylint: disable=singleton-comparison, g-explicit-bool-comparison
       name='Upload Reward Model Tensorboard Metrics',
@@ -194,13 +194,13 @@ def rlhf_pipeline(
             f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
             f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
         ),
-    )
+    ).set_display_name('Reward Model Tensorboard Metrics Uploader')
 
   rl_image_uri = function_based.resolve_private_image_uri(
       image_name='reinforcer',
       accelerator_type=machine_spec.outputs['accelerator_type'],
       accelerator_count=machine_spec.outputs['accelerator_count'],
-  ).set_display_name('ReinforcerImageUriResolver')
+  ).set_display_name('Resolve Reinforcer Image URI')
   rl_model = (
       reinforcer.Reinforcer(
           project=project,
@@ -246,9 +246,11 @@ def rlhf_pipeline(
             f'{kfp.dsl.PIPELINE_JOB_ID_PLACEHOLDER}-'
             f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
         ),
-    )
+    ).set_display_name('Reinforcement Learning Tensorboard Metrics Uploader')
 
-  should_perform_inference = function_based.value_exists(value=eval_dataset)
+  should_perform_inference = function_based.value_exists(
+      value=eval_dataset
+  ).set_display_name('Resolve Inference Dataset')
   with kfp.dsl.Condition(
       should_perform_inference.output == True, name='Perform Inference'  # pylint: disable=singleton-comparison
   ):
@@ -266,22 +268,23 @@ def rlhf_pipeline(
   adapter_artifact = kfp.dsl.importer(
       artifact_uri=rl_model.outputs['output_adapter_path'],
       artifact_class=kfp.dsl.Artifact,
-  )
+  ).set_display_name('Import Tuned Adapter')
   regional_endpoint = function_based.resolve_regional_endpoint(
       upload_location=upload_location
-  )
+  ).set_display_name('Resolve Regional Endpoint')
   display_name = function_based.resolve_model_display_name(
       large_model_reference=reference_model_metadata.outputs[
           'large_model_reference'
       ],
       model_display_name=model_display_name,
-  )
+  ).set_display_name('Resolve Model Display Name')
   upload_model = function_based.resolve_upload_model(
       large_model_reference=reference_model_metadata.outputs[
           'large_model_reference'
       ]
-  )
+  ).set_display_name('Resolve Upload Model')
-  upload_task = upload_llm_model.upload_llm_model(
+  upload_task = (
+      upload_llm_model.upload_llm_model(
           project=_placeholders.PROJECT_ID_PLACEHOLDER,
           location=upload_location,
           regional_endpoint=regional_endpoint.output,
@@ -289,16 +292,19 @@ def rlhf_pipeline(
           model_display_name=display_name.output,
           model_reference_name='text-bison@001',
           upload_model=upload_model.output,
-  ).set_env_variable(
+      )
+      .set_env_variable(
           name='VERTEX_AI_PIPELINES_RUN_LABELS',
           value=json.dumps({'tune-type': 'rlhf'}),
       )
+      .set_display_name('Upload Model')
+  )
   deploy_model = function_based.resolve_deploy_model(
       deploy_model=deploy_model,
       large_model_reference=reference_model_metadata.outputs[
           'large_model_reference'
       ],
-  )
+  ).set_display_name('Resolve Deploy Model')
   deploy_task = deploy_llm_model.create_endpoint_and_deploy_model(
       project=_placeholders.PROJECT_ID_PLACEHOLDER,
       location=upload_location,
@@ -306,7 +312,7 @@ def rlhf_pipeline(
       display_name=display_name.output,
       regional_endpoint=regional_endpoint.output,
       deploy_model=deploy_model.output,
-  )
+  ).set_display_name('Deploy Model')
 
   return PipelineOutput(
       model_resource_name=upload_task.outputs['model_resource_name'],
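The gating seen above, a small `value_exists` task whose boolean output feeds `kfp.dsl.Condition`, with both the task and the condition group given readable names, is plain KFP. A self-contained sketch with made-up component names:

from kfp import dsl


@dsl.component
def value_exists(value: str) -> bool:
    # True when an optional pipeline parameter was actually supplied.
    return bool(value)


@dsl.component
def upload_metrics(tensorboard_id: str) -> str:
    return f'uploaded metrics to {tensorboard_id}'


@dsl.pipeline(name='conditional-upload-demo')
def demo(tensorboard_resource_id: str = ''):
    has_id = value_exists(value=tensorboard_resource_id).set_display_name(
        'Resolve Tensorboard Resource ID'
    )
    # The condition's name becomes the label of the group in the pipeline graph.
    with dsl.Condition(has_id.output == True, name='Upload Tensorboard Metrics'):  # pylint: disable=singleton-comparison
        upload_metrics(tensorboard_id=tensorboard_resource_id).set_display_name(
            'Tensorboard Metrics Uploader'
        )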