feat(components): Make `model_checkpoint` optional for `preview.llm.infer_pipeline`

PiperOrigin-RevId: 574876480
This commit is contained in:
Googler 2023-10-19 08:14:21 -07:00 committed by Google Cloud Pipeline Components maintainers
parent d8a0660df5
commit e8fb6990df
2 changed files with 6 additions and 4 deletions

View File

@ -6,6 +6,7 @@
* Add sliced evaluation metrics support for custom and unstructured AutoML models in evaluation pipeline and evaluation pipeline with feature attribution.
* Support `service_account` in `ModelBatchPredictOp`.
* Release `DataflowFlexTemplateJobOp` to GA namespace (`v1.dataflow.DataflowFlexTemplateJobOp`).
* Make `model_checkpoint` optional for `preview.llm.infer_pipeline`. If not provided, the base model associated with the `large_model_reference` will be used.
## Release 2.4.1
* Disable caching for LLM pipeline tasks that store temporary artifacts.

View File

@ -33,8 +33,8 @@ PipelineOutput = NamedTuple('Outputs', output_prediction_gcs_path=str)
)
def infer_pipeline(
large_model_reference: str,
model_checkpoint: str,
prompt_dataset: str,
model_checkpoint: Optional[str] = None,
prompt_sequence_length: int = 512,
target_sequence_length: int = 64,
sampling_strategy: str = 'greedy',
@ -47,7 +47,7 @@ def infer_pipeline(
Args:
large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
model_checkpoint: Cloud storage path to the model checkpoint.
model_checkpoint: Optional Cloud storage path to the model checkpoint. If not provided, the default checkpoint for the `large_model_reference` will be used.
prompt_dataset: Cloud storage path to an unlabled prompt dataset used for reinforcement learning. The dataset format is jsonl. Each example in the dataset must have an `input_text` field that contains the prompt.
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
@ -66,7 +66,8 @@ def infer_pipeline(
use_test_spec=env.get_use_test_machine_spec(),
).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference
large_model_reference=large_model_reference,
reference_model_path=model_checkpoint,
).set_display_name('Resolve Model Metadata')
prompt_dataset_image_uri = function_based.resolve_private_image_uri(
@ -98,7 +99,7 @@ def infer_pipeline(
bulk_inference = bulk_inferrer.BulkInferrer(
project=project,
location=location,
input_model=model_checkpoint,
input_model=reference_model_metadata.outputs['reference_model_path'],
input_dataset_path=prompt_dataset_importer.outputs['imported_data_path'],
dataset_split=env.TRAIN_SPLIT,
inputs_sequence_length=prompt_sequence_length,