feat(components): Make `model_checkpoint` optional for `preview.llm.infer_pipeline`
PiperOrigin-RevId: 574876480
parent d8a0660df5
commit e8fb6990df
@@ -6,6 +6,7 @@
 * Add sliced evaluation metrics support for custom and unstructured AutoML models in evaluation pipeline and evaluation pipeline with feature attribution.
 * Support `service_account` in `ModelBatchPredictOp`.
 * Release `DataflowFlexTemplateJobOp` to GA namespace (`v1.dataflow.DataflowFlexTemplateJobOp`).
+* Make `model_checkpoint` optional for `preview.llm.infer_pipeline`. If not provided, the base model associated with the `large_model_reference` will be used.
 
 ## Release 2.4.1
 * Disable caching for LLM pipeline tasks that store temporary artifacts.
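To make the new behavior concrete, here is a minimal compile sketch (assuming `google-cloud-pipeline-components` with this change and KFP v2 installed; the output filename is arbitrary):

```python
# Compile the inference pipeline; after this change `model_checkpoint` is an
# optional input, so callers can bind only the required parameters at runtime.
from kfp import compiler
from google_cloud_pipeline_components.preview.llm import infer_pipeline

compiler.Compiler().compile(
    pipeline_func=infer_pipeline,
    package_path='infer_pipeline.yaml',  # arbitrary local output path
)
```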
@@ -33,8 +33,8 @@ PipelineOutput = NamedTuple('Outputs', output_prediction_gcs_path=str)
 )
 def infer_pipeline(
     large_model_reference: str,
-    model_checkpoint: str,
     prompt_dataset: str,
+    model_checkpoint: Optional[str] = None,
     prompt_sequence_length: int = 512,
     target_sequence_length: int = 64,
     sampling_strategy: str = 'greedy',
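The reordering above is required, not cosmetic: once `model_checkpoint` gains a default value, Python forces it below every parameter without one. A standalone illustration of that rule:

```python
from typing import Optional

# Fine: the defaulted parameter comes after the required one.
def ok(required: str, optional: Optional[str] = None) -> str:
    return optional or f'default-for-{required}'

# SyntaxError if uncommented: a parameter without a default may not follow
# one that has a default.
# def broken(optional: Optional[str] = None, required: str) -> str: ...
```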
@@ -47,7 +47,7 @@ def infer_pipeline(
 
   Args:
     large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
-    model_checkpoint: Cloud storage path to the model checkpoint.
+    model_checkpoint: Optional Cloud storage path to the model checkpoint. If not provided, the default checkpoint for the `large_model_reference` will be used.
     prompt_dataset: Cloud storage path to an unlabeled prompt dataset used for reinforcement learning. The dataset format is jsonl. Each example in the dataset must have an `input_text` field that contains the prompt.
     prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
     target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
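For concreteness, one record of the JSONL `prompt_dataset` described above could be produced like this (the prompt text is illustrative; `input_text` is the only field the docstring requires):

```python
import json

# Each line of the prompt dataset is a standalone JSON object.
record = {'input_text': 'Summarize the following article: ...'}
print(json.dumps(record))  # {"input_text": "Summarize the following article: ..."}
```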
@@ -66,7 +66,8 @@ def infer_pipeline(
       use_test_spec=env.get_use_test_machine_spec(),
   ).set_display_name('Resolve Machine Spec')
   reference_model_metadata = function_based.resolve_reference_model_metadata(
-      large_model_reference=large_model_reference
+      large_model_reference=large_model_reference,
+      reference_model_path=model_checkpoint,
   ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
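The diff does not show the inside of `resolve_reference_model_metadata`, but the fallback it implies is straightforward. A hypothetical sketch (the function name, registry, and default path here are illustrative, not the component's real implementation):

```python
from typing import Optional

# Illustrative registry of default checkpoints keyed by base model name.
_DEFAULT_CHECKPOINTS = {
    't5-small': 'gs://example-bucket/t5-small/checkpoint',  # hypothetical path
}

def resolve_reference_model_path(
    large_model_reference: str,
    reference_model_path: Optional[str] = None,
) -> str:
    # Prefer the caller's checkpoint; otherwise fall back to the default
    # checkpoint registered for the base model.
    return reference_model_path or _DEFAULT_CHECKPOINTS[large_model_reference]
```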
@@ -98,7 +99,7 @@ def infer_pipeline(
   bulk_inference = bulk_inferrer.BulkInferrer(
       project=project,
       location=location,
-      input_model=model_checkpoint,
+      input_model=reference_model_metadata.outputs['reference_model_path'],
       input_dataset_path=prompt_dataset_importer.outputs['imported_data_path'],
       dataset_split=env.TRAIN_SPLIT,
       inputs_sequence_length=prompt_sequence_length,
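With `input_model` now wired to the resolver's output rather than the raw parameter, omitting `model_checkpoint` at submission time just works. A hedged end-to-end sketch using the compiled template from earlier (project, region, and bucket names are placeholders):

```python
from google.cloud import aiplatform

aiplatform.init(project='my-project', location='us-central1')  # placeholders

aiplatform.PipelineJob(
    display_name='llm-infer',
    template_path='infer_pipeline.yaml',
    parameter_values={
        'large_model_reference': 't5-small',
        'prompt_dataset': 'gs://example-bucket/prompts.jsonl',  # placeholder
        # 'model_checkpoint' omitted: the default checkpoint for the base
        # model is resolved inside the pipeline.
    },
).run()
```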