feat(components): Update supported large model reference names that can be resolved by function based component in _implementation/llm
PiperOrigin-RevId: 559493244
This commit is contained in:
parent
f43272dee8
commit
9ce2866527
|
|
@ -13,225 +13,12 @@
|
|||
# limitations under the License.
|
||||
"""KFP Container component that performs bulk inference."""
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from google_cloud_pipeline_components import _image
|
||||
from google_cloud_pipeline_components import utils as gcpc_utils
|
||||
from google_cloud_pipeline_components._implementation.llm import utils
|
||||
import kfp
|
||||
|
||||
|
||||
@kfp.dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def get_default_bulk_inference_machine_specs(
    large_model_reference: str,
    use_gpu_defaults: bool = False,
    accelerator_type_override: Optional[str] = None,
    accelerator_count_override: Optional[int] = None,
) -> NamedTuple(
    'MachineSpec', accelerator_type=str, accelerator_count=int, machine_type=str
):
  """Resolves machine specs for bulk inference, honoring optional overrides.

  Args:
    large_model_reference: Foundational model to use for default specs.
    use_gpu_defaults: Whether to get default gpu specs (otherwise will get TPU
      specs).
    accelerator_type_override: Accelerator type to override the default.
    accelerator_count_override: Accelerator count to override the default.

  Returns:
    MachineSpec, including accelerator_type, accelerator_count, machine_type.

  Raises:
    ValueError: If large_model_reference is invalid or overridden values are
      invalid.
  """
  # pylint: disable=g-import-not-at-top,redefined-outer-name,reimported
  import collections
  # pylint: enable=g-import-not-at-top,redefined-outer-name,reimported

  machine_spec = collections.namedtuple(
      'MachineSpec', ['accelerator_type', 'accelerator_count', 'machine_type']
  )

  # Machine types.
  cloud_tpu = 'cloud-tpu'
  high_gpu_1g = 'a2-highgpu-1g'
  high_gpu_2g = 'a2-highgpu-2g'
  high_gpu_4g = 'a2-highgpu-4g'
  high_gpu_8g = 'a2-highgpu-8g'
  mega_gpu_16g = 'a2-megagpu-16g'
  ultra_gpu_1g = 'a2-ultragpu-1g'
  ultra_gpu_2g = 'a2-ultragpu-2g'
  ultra_gpu_4g = 'a2-ultragpu-4g'
  ultra_gpu_8g = 'a2-ultragpu-8g'

  # Accelerator types.
  tpu_v2 = 'TPU_V2'
  tpu_v3 = 'TPU_V3'
  nvidia_a100_40g = 'NVIDIA_TESLA_A100'
  nvidia_a100_80g = 'NVIDIA_A100_80GB'
  tpu_accelerator_types = frozenset([tpu_v2, tpu_v3])
  gpu_accelerator_types = frozenset([nvidia_a100_40g, nvidia_a100_80g])
  valid_accelerator_types = tpu_accelerator_types | gpu_accelerator_types

  # Base models.
  palm_tiny = 'PALM_TINY'
  gecko = 'GECKO'
  otter = 'OTTER'
  bison = 'BISON'
  elephant = 'ELEPHANT'
  t5_small = 'T5_SMALL'
  t5_large = 'T5_LARGE'
  t5_xl = 'T5_XL'
  t5_xxl = 'T5_XXL'

  # (inclusive upper bound on accelerator count, machine type), in ascending
  # order, per GPU accelerator family.
  a100_40g_machines = (
      (1, high_gpu_1g),
      (2, high_gpu_2g),
      (4, high_gpu_4g),
      (8, high_gpu_8g),
      (16, mega_gpu_16g),
  )
  a100_80g_machines = (
      (1, ultra_gpu_1g),
      (2, ultra_gpu_2g),
      (4, ultra_gpu_4g),
      (8, ultra_gpu_8g),
  )

  def _get_machine_type(accelerator_type: str, accelerator_count: int) -> str:
    """Maps an accelerator type and count to the matching machine type."""
    if accelerator_count < 1:
      raise ValueError('accelerator_count must be at least 1.')
    if accelerator_type in tpu_accelerator_types:
      return cloud_tpu
    if accelerator_type == nvidia_a100_40g:
      gpu_machines, max_count = a100_40g_machines, 16
    elif accelerator_type == nvidia_a100_80g:
      gpu_machines, max_count = a100_80g_machines, 8
    else:
      raise ValueError(
          'accelerator_type_override must be one of'
          f' {sorted(valid_accelerator_types)}.'
      )
    for count_limit, machine_type in gpu_machines:
      if accelerator_count <= count_limit:
        return machine_type
    raise ValueError(
        f'Too many {accelerator_type} requested. Must be <= {max_count}.'
    )

  accepted_reference_models = frozenset(
      [palm_tiny, gecko, otter, bison, elephant, t5_small, t5_xxl]
  )

  # Default GPU specs are based on study here:
  # https://docs.google.com/spreadsheets/d/1_ZKqfyLQ5vYrOQH5kfdMb_OoNT48r6vNbqv3dKDxDTw/edit?resourcekey=0-3kgDrn4XDdvlJAc8Kils-Q#gid=255356424
  default_gpu_counts = {
      palm_tiny: 1,
      gecko: 1,
      otter: 2,
      bison: 8,
      elephant: 8,
      t5_small: 1,
      t5_large: 1,
      t5_xl: 1,
      t5_xxl: 2,
  }
  reference_model_to_model_specs_gpu = {
      model: machine_spec(
          accelerator_type=nvidia_a100_40g,
          accelerator_count=count,
          machine_type=_get_machine_type(nvidia_a100_40g, count),
      )
      for model, count in default_gpu_counts.items()
  }

  # Get defaults.
  if large_model_reference not in accepted_reference_models:
    raise ValueError(
        'large_model_reference must be one of'
        f' {sorted(accepted_reference_models)}.'
    )

  if use_gpu_defaults:
    default_machine_spec = reference_model_to_model_specs_gpu[
        large_model_reference
    ]
  else:
    # This is the only config available for TPUs in our shared reservation pool.
    default_machine_spec = machine_spec(
        accelerator_type=tpu_v3,
        accelerator_count=32,
        machine_type=cloud_tpu,
    )

  # Apply overrides; deeper validation of these values is deferred to the
  # resource provisioner.
  has_type_override = bool(accelerator_type_override)
  has_count_override = bool(accelerator_count_override)
  if has_type_override or has_count_override:
    if not (has_type_override and has_count_override):
      raise ValueError('Accelerator type and count must both be set.')
    accelerator_type = accelerator_type_override
    accelerator_count = accelerator_count_override
  else:
    accelerator_type = default_machine_spec.accelerator_type
    accelerator_count = default_machine_spec.accelerator_count

  return machine_spec(
      accelerator_type,
      accelerator_count,
      _get_machine_type(accelerator_type, accelerator_count),
  )
|
||||
|
||||
|
||||
@kfp.dsl.container_component
|
||||
def BulkInferrer( # pylint: disable=invalid-name
|
||||
project: str,
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ def resolve_reference_model_metadata(
|
|||
large_model_reference: str,
|
||||
reference_model_path: Optional[str] = None,
|
||||
) -> NamedTuple(
|
||||
'BaseModelMetadata',
|
||||
'Outputs',
|
||||
large_model_reference=str,
|
||||
reference_model_path=str,
|
||||
reward_model_reference=str,
|
||||
|
|
@ -181,75 +181,160 @@ def resolve_reference_model_metadata(
|
|||
Raises:
|
||||
ValueError: if no metadata exists for the given base model.
|
||||
"""
|
||||
reference_model_metadata = NamedTuple(
|
||||
'ReferenceModelMetadata',
|
||||
large_model_reference=str,
|
||||
reference_model_path=str,
|
||||
reward_model_reference=str,
|
||||
reward_model_path=str,
|
||||
is_supported=bool,
|
||||
)
|
||||
|
||||
reference_models = {
|
||||
't5-small': reference_model_metadata(
|
||||
large_model_reference='T5_SMALL',
|
||||
reference_model_path=(
|
||||
'gs://t5-data/pretrained_models/t5x/flan_t5_small/'
|
||||
),
|
||||
reward_model_reference='T5_SMALL',
|
||||
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_small',
|
||||
is_supported=True,
|
||||
),
|
||||
't5-large': reference_model_metadata(
|
||||
large_model_reference='T5_LARGE',
|
||||
reference_model_path=(
|
||||
'gs://t5-data/pretrained_models/t5x/flan_t5_large/'
|
||||
),
|
||||
reward_model_reference='T5_LARGE',
|
||||
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_large',
|
||||
is_supported=True,
|
||||
),
|
||||
't5-xl': reference_model_metadata(
|
||||
large_model_reference='T5_XL',
|
||||
reference_model_path='gs://t5-data/pretrained_models/t5x/flan_t5_xl/',
|
||||
reward_model_reference='T5_XL',
|
||||
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xl',
|
||||
is_supported=True,
|
||||
),
|
||||
't5-xxl': reference_model_metadata(
|
||||
large_model_reference='T5_XXL',
|
||||
reference_model_path=(
|
||||
'gs://t5-data/pretrained_models/t5x/flan_t5_xxl/'
|
||||
),
|
||||
reward_model_reference='T5_XXL',
|
||||
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xxl',
|
||||
is_supported=True,
|
||||
),
|
||||
'palm-tiny': reference_model_metadata(
|
||||
large_model_reference='PALM_TINY',
|
||||
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/',
|
||||
reward_model_reference='PALM_TINY',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/',
|
||||
is_supported=False,
|
||||
),
|
||||
'gecko': reference_model_metadata(
|
||||
large_model_reference='GECKO',
|
||||
reference_model_path=(
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko/'
|
||||
),
|
||||
reward_model_reference='GECKO',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko_pretrain/',
|
||||
is_supported=False,
|
||||
),
|
||||
'otter': reference_model_metadata(
|
||||
large_model_reference='OTTER',
|
||||
reference_model_path=(
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter/'
|
||||
),
|
||||
reward_model_reference='OTTER',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
|
||||
is_supported=False,
|
||||
),
|
||||
'bison': reference_model_metadata(
|
||||
large_model_reference='BISON',
|
||||
reference_model_path=(
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
|
||||
),
|
||||
reward_model_reference='OTTER',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
|
||||
is_supported=False, # Deprecated: Use text-bision@001 instead.
|
||||
),
|
||||
'text-bison@001': reference_model_metadata(
|
||||
large_model_reference='BISON',
|
||||
reference_model_path=(
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
|
||||
),
|
||||
reward_model_reference='OTTER',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
|
||||
is_supported=True,
|
||||
),
|
||||
'elephant': reference_model_metadata(
|
||||
large_model_reference='ELEPHANT',
|
||||
reference_model_path=(
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
|
||||
),
|
||||
reward_model_reference='OTTER',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
|
||||
is_supported=False,
|
||||
),
|
||||
'llama-2-7b': reference_model_metadata(
|
||||
large_model_reference='LLAMA_2_7B',
|
||||
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
|
||||
reward_model_reference='LLAMA_2_7B',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
|
||||
is_supported=True,
|
||||
),
|
||||
'llama-2-13b': reference_model_metadata(
|
||||
large_model_reference='LLAMA_2_13B',
|
||||
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
|
||||
reward_model_reference='LLAMA_2_13B',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
|
||||
is_supported=True,
|
||||
),
|
||||
'llama-2-7b-chat': reference_model_metadata(
|
||||
large_model_reference='LLAMA_2_7B_CHAT',
|
||||
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b_chat/',
|
||||
reward_model_reference='LLAMA_2_7B_CHAT',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b_chat/',
|
||||
is_supported=True,
|
||||
),
|
||||
'llama-2-13b-chat': reference_model_metadata(
|
||||
large_model_reference='LLAMA_2_13B_CHAT',
|
||||
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
|
||||
reward_model_reference='LLAMA_2_13B_CHAT',
|
||||
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
|
||||
is_supported=True,
|
||||
),
|
||||
}
|
||||
|
||||
reference_model_key = large_model_reference.lower().replace('_', '-')
|
||||
if reference_model_key not in reference_models:
|
||||
supported_models = [
|
||||
k for k, v in reference_models.items() if v.is_supported
|
||||
]
|
||||
raise ValueError(
|
||||
f'Unknown reference model {large_model_reference}.'
|
||||
' large_model_reference must be one of'
|
||||
f' {sorted(supported_models)}.'
|
||||
)
|
||||
|
||||
reference_model = reference_models[reference_model_key]
|
||||
|
||||
# TODO(latture): Move this logic to a container component and use
|
||||
# PredefinedModels enum to resolve model paths.
|
||||
outputs = NamedTuple(
|
||||
'BaseModelMetadata',
|
||||
'Outputs',
|
||||
large_model_reference=str,
|
||||
reference_model_path=str,
|
||||
reward_model_reference=str,
|
||||
reward_model_path=str,
|
||||
)
|
||||
reference_model_key = large_model_reference.upper().replace('-', '_')
|
||||
predefined_model_paths = {
|
||||
'PALM_TINY': (
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/'
|
||||
),
|
||||
'GECKO': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko/',
|
||||
'OTTER': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter/',
|
||||
'BISON': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/',
|
||||
'ELEPHANT': (
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
|
||||
),
|
||||
'T5_SMALL': 'gs://t5-data/pretrained_models/t5x/flan_t5_small/',
|
||||
'T5_LARGE': 'gs://t5-data/pretrained_models/t5x/flan_t5_large/',
|
||||
'T5_XL': 'gs://t5-data/pretrained_models/t5x/flan_t5_xl/',
|
||||
'T5_XXL': 'gs://t5-data/pretrained_models/t5x/flan_t5_xxl/',
|
||||
}
|
||||
predefined_reward_model_paths = {
|
||||
'PALM_TINY': (
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny'
|
||||
),
|
||||
'GECKO': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko_pretrain',
|
||||
'OTTER': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain',
|
||||
'ELEPHANT': (
|
||||
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
|
||||
),
|
||||
'T5_SMALL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_small',
|
||||
'T5_LARGE': 'gs://t5-data/pretrained_models/t5x/t5_1_1_large',
|
||||
'T5_XL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_xl',
|
||||
'T5_XXL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_xxl',
|
||||
}
|
||||
|
||||
if reference_model_key not in predefined_model_paths:
|
||||
raise ValueError(
|
||||
f'No metadata found for `{reference_model_key}`. '
|
||||
f'Base model must be one of {list(predefined_model_paths.keys())}.'
|
||||
)
|
||||
|
||||
# Mapping from base model to its corresponding reward model.
|
||||
reference_model_to_reward_model = {
|
||||
'PALM_TINY': 'PALM_TINY',
|
||||
'GECKO': 'GECKO',
|
||||
'OTTER': 'OTTER',
|
||||
'BISON': 'OTTER',
|
||||
'ELEPHANT': 'ELEPHANT',
|
||||
'T5_SMALL': 'T5_SMALL',
|
||||
'T5_LARGE': 'T5_LARGE',
|
||||
'T5_XL': 'T5_XL',
|
||||
'T5_XXL': 'T5_XXL',
|
||||
}
|
||||
|
||||
reward_model_key = reference_model_to_reward_model[reference_model_key]
|
||||
|
||||
return outputs(
|
||||
large_model_reference=reference_model_key,
|
||||
large_model_reference=reference_model.large_model_reference,
|
||||
reference_model_path=(
|
||||
reference_model_path or predefined_model_paths[reference_model_key]
|
||||
reference_model_path or reference_model.reference_model_path
|
||||
),
|
||||
reward_model_reference=reward_model_key,
|
||||
reward_model_path=predefined_reward_model_paths[reward_model_key],
|
||||
reward_model_reference=reference_model.reward_model_reference,
|
||||
reward_model_path=reference_model.reward_model_path,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -46,9 +46,9 @@ def infer_pipeline(
|
|||
|
||||
Args:
|
||||
large_model_reference: Name of the base model. Supported values are
|
||||
``BISON``, ``T5_SMALL``, ``T5_LARGE``, ``T5_XL``, and ``T5_XXL``.
|
||||
``BISON`` and ``T5_SMALL`` are supported in ``us-central1`` and
|
||||
``europe-west4``. ``T5_LARGE``, ``T5_XL`` and ``T5_XXL`` are only
|
||||
``text-bison@001``, ``t5-small``, ``t5-large``, ``t5-xl`` and ``t5-xxl``.
|
||||
``text-bison@001`` and ``t5-small`` are supported in ``us-central1`` and
|
||||
``europe-west4``. ``t5-large``, ``t5-xl`` and ``t5-xxl`` are only
|
||||
supported in ``europe-west4``.
|
||||
model_checkpoint: Cloud storage path to the model checkpoint.
|
||||
prompt_dataset: Cloud storage path to an unlabled prompt dataset used for
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ def rlhf_pipeline(
|
|||
the prompt, ``candidate_0`` and ``candidate_1`` that contain candidate
|
||||
responses, ``choice`` that specifies the preferred candidate.
|
||||
large_model_reference: Name of the base model. Supported values are
|
||||
``BISON``, ``T5_SMALL``, ``T5_LARGE``, ``T5_XL``, and ``T5_XXL``.
|
||||
``BISON`` and ``T5_SMALL`` are supported in ``us-central1`` and
|
||||
``europe-west4``. ``T5_LARGE``, ``T5_XL`` and ``T5_XXL`` are only
|
||||
``text-bison@001``, ``t5-small``, ``t5-large``, ``t5-xl`` and ``t5-xxl``.
|
||||
``text-bison@001`` and ``t5-small`` are supported in ``us-central1`` and
|
||||
``europe-west4``. ``t5-large``, ``t5-xl`` and ``t5-xxl`` are only
|
||||
supported in ``europe-west4``.
|
||||
model_display_name: Name of the fine-tuned model shown in the Model
|
||||
Registry. If not provided, a default name will be created.
|
||||
|
|
|
|||
Loading…
Reference in New Issue