feat(components): Update supported large model reference names that can be resolved by the function-based component in _implementation/llm

PiperOrigin-RevId: 559493244
Googler authored 2023-08-23 11:35:37 -07:00; committed by Google Cloud Pipeline Components maintainers
parent f43272dee8
commit 9ce2866527
4 changed files with 150 additions and 278 deletions
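In practice, this commit moves the user-facing reference names from enum-style identifiers (``T5_SMALL``, ``BISON``) to lowercase, versioned model IDs (``t5-small``, ``text-bison@001``). Below is a minimal sketch, not part of this diff, of the key normalization the updated component applies (it mirrors the lookup in resolve_reference_model_metadata further down), which keeps legacy enum-style names resolving:

def normalize_reference_name(large_model_reference: str) -> str:
    # Same transformation used to build the lookup key in the new code below.
    return large_model_reference.lower().replace('_', '-')

assert normalize_reference_name('T5_SMALL') == 't5-small'
assert normalize_reference_name('LLAMA_2_7B') == 'llama-2-7b'
assert normalize_reference_name('text-bison@001') == 'text-bison@001'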


@@ -13,225 +13,12 @@
# limitations under the License.
"""KFP Container component that performs bulk inference."""
from typing import NamedTuple, Optional
from google_cloud_pipeline_components import _image
from google_cloud_pipeline_components import utils as gcpc_utils
from google_cloud_pipeline_components._implementation.llm import utils
import kfp
@kfp.dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def get_default_bulk_inference_machine_specs(
large_model_reference: str,
use_gpu_defaults: bool = False,
accelerator_type_override: Optional[str] = None,
accelerator_count_override: Optional[int] = None,
) -> NamedTuple(
'MachineSpec', accelerator_type=str, accelerator_count=int, machine_type=str
):
"""Gets default machine specs for bulk inference and overrides params if provided.
Args:
large_model_reference: Foundational model to use for default specs.
use_gpu_defaults: Whether to get default gpu specs (otherwise will get TPU
specs).
accelerator_type_override: Accelerator type to override the default.
accelerator_count_override: Accelerator count to override the default.
Returns:
MachineSpec, including accelerator_type, accelerator_count, machine_type.
Raises:
ValueError: If large_model_reference is invalid or overridden values are
invalid.
"""
# pylint: disable=g-import-not-at-top,redefined-outer-name,reimported
import collections
# pylint: enable=g-import-not-at-top,redefined-outer-name,reimported
machine_spec = collections.namedtuple(
'MachineSpec', ['accelerator_type', 'accelerator_count', 'machine_type']
)
# machine types
cloud_tpu = 'cloud-tpu'
ultra_gpu_1g = 'a2-ultragpu-1g'
ultra_gpu_2g = 'a2-ultragpu-2g'
ultra_gpu_4g = 'a2-ultragpu-4g'
ultra_gpu_8g = 'a2-ultragpu-8g'
high_gpu_1g = 'a2-highgpu-1g'
high_gpu_2g = 'a2-highgpu-2g'
high_gpu_4g = 'a2-highgpu-4g'
high_gpu_8g = 'a2-highgpu-8g'
mega_gpu_16g = 'a2-megagpu-16g'
# accelerator types
tpu_v2 = 'TPU_V2'
tpu_v3 = 'TPU_V3'
nvidia_a100_40g = 'NVIDIA_TESLA_A100'
nvidia_a100_80g = 'NVIDIA_A100_80GB'
tpu_accelerator_types = frozenset([tpu_v2, tpu_v3])
gpu_accelerator_types = frozenset([nvidia_a100_40g, nvidia_a100_80g])
valid_accelerator_types = frozenset(
list(gpu_accelerator_types) + list(tpu_accelerator_types)
)
# base models
palm_tiny = 'PALM_TINY'
gecko = 'GECKO'
otter = 'OTTER'
bison = 'BISON'
elephant = 'ELEPHANT'
t5_small = 'T5_SMALL'
t5_large = 'T5_LARGE'
t5_xl = 'T5_XL'
t5_xxl = 'T5_XXL'
def _get_machine_type(accelerator_type: str, accelerator_count: int) -> str:
if accelerator_count < 1:
raise ValueError('accelerator_count must be at least 1.')
if accelerator_type in tpu_accelerator_types:
return cloud_tpu
elif accelerator_type == nvidia_a100_40g:
if accelerator_count == 1:
return high_gpu_1g
elif accelerator_count == 2:
return high_gpu_2g
elif accelerator_count <= 4:
return high_gpu_4g
elif accelerator_count <= 8:
return high_gpu_8g
elif accelerator_count <= 16:
return mega_gpu_16g
else:
raise ValueError(
f'Too many {accelerator_type} requested. Must be <= 16.'
)
elif accelerator_type == nvidia_a100_80g:
if accelerator_count == 1:
return ultra_gpu_1g
elif accelerator_count == 2:
return ultra_gpu_2g
elif accelerator_count <= 4:
return ultra_gpu_4g
elif accelerator_count <= 8:
return ultra_gpu_8g
else:
raise ValueError(
f'Too many {accelerator_type} requested. Must be <= 8.'
)
else:
raise ValueError(
'accelerator_type_override must be one of'
f' {sorted(valid_accelerator_types)}.'
)
accepted_reference_models = frozenset(
[palm_tiny, gecko, otter, bison, elephant, t5_small, t5_xxl]
)
# Default GPU specs are based on study here:
# https://docs.google.com/spreadsheets/d/1_ZKqfyLQ5vYrOQH5kfdMb_OoNT48r6vNbqv3dKDxDTw/edit?resourcekey=0-3kgDrn4XDdvlJAc8Kils-Q#gid=255356424
reference_model_to_model_specs_gpu = {
palm_tiny: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=1,
machine_type=high_gpu_1g,
),
gecko: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=1,
machine_type=high_gpu_1g,
),
otter: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=2,
machine_type=high_gpu_2g,
),
bison: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=8,
machine_type=high_gpu_8g,
),
elephant: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=8,
machine_type=high_gpu_8g,
),
t5_small: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=1,
machine_type=high_gpu_1g,
),
t5_large: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=1,
machine_type=high_gpu_1g,
),
t5_xl: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=1,
machine_type=high_gpu_1g,
),
t5_xxl: machine_spec(
accelerator_type=nvidia_a100_40g,
accelerator_count=2,
machine_type=high_gpu_2g,
),
}
# Get defaults
if large_model_reference not in accepted_reference_models:
raise ValueError(
'large_model_reference must be one of'
f' {sorted(accepted_reference_models)}.'
)
if use_gpu_defaults:
default_machine_spec = reference_model_to_model_specs_gpu[
large_model_reference
]
else:
# This is the only config available for TPUs in our shared reservation pool.
default_machine_spec = machine_spec(
accelerator_type=tpu_v3,
accelerator_count=32,
machine_type=cloud_tpu,
)
# Override default behavior; we defer validation of these values to the
# resource provisioner.
if any([accelerator_type_override, accelerator_count_override]):
if not all([accelerator_type_override, accelerator_count_override]):
raise ValueError('Accelerator type and count must both be set.')
accelerator_type = accelerator_type_override
accelerator_count = accelerator_count_override
else:
accelerator_type = default_machine_spec.accelerator_type
accelerator_count = default_machine_spec.accelerator_count
return machine_spec(
accelerator_type,
accelerator_count,
_get_machine_type(accelerator_type, accelerator_count),
)
@kfp.dsl.container_component
def BulkInferrer( # pylint: disable=invalid-name
project: str,
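Per the hunk header above (225 lines dropped, 12 kept), the get_default_bulk_inference_machine_specs component shown in this hunk appears to be removed by this commit. For reference, a minimal sketch of how it resolved machine specs, assuming KFP v2's PythonComponent exposes the wrapped function as .python_func; the expected values follow directly from the logic above:

fn = get_default_bulk_inference_machine_specs.python_func

# Default path: the single TPU config available in the shared reservation pool.
spec = fn(large_model_reference='BISON')
assert (spec.accelerator_type, spec.accelerator_count) == ('TPU_V3', 32)
assert spec.machine_type == 'cloud-tpu'

# GPU defaults: BISON maps to eight A100 40GB GPUs on a2-highgpu-8g.
spec = fn(large_model_reference='BISON', use_gpu_defaults=True)
assert spec.machine_type == 'a2-highgpu-8g'

# Overrides must set both type and count; the machine type is then derived.
spec = fn(
    large_model_reference='BISON',
    accelerator_type_override='NVIDIA_A100_80GB',
    accelerator_count_override=4,
)
assert spec.machine_type == 'a2-ultragpu-4g'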


@@ -160,7 +160,7 @@ def resolve_reference_model_metadata(
large_model_reference: str,
reference_model_path: Optional[str] = None,
) -> NamedTuple(
- 'BaseModelMetadata',
+ 'Outputs',
large_model_reference=str,
reference_model_path=str,
reward_model_reference=str,
@@ -181,75 +181,160 @@ def resolve_reference_model_metadata(
Raises:
ValueError: if no metadata exists for the given base model.
"""
reference_model_metadata = NamedTuple(
'ReferenceModelMetadata',
large_model_reference=str,
reference_model_path=str,
reward_model_reference=str,
reward_model_path=str,
is_supported=bool,
)
reference_models = {
't5-small': reference_model_metadata(
large_model_reference='T5_SMALL',
reference_model_path=(
'gs://t5-data/pretrained_models/t5x/flan_t5_small/'
),
reward_model_reference='T5_SMALL',
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_small',
is_supported=True,
),
't5-large': reference_model_metadata(
large_model_reference='T5_LARGE',
reference_model_path=(
'gs://t5-data/pretrained_models/t5x/flan_t5_large/'
),
reward_model_reference='T5_LARGE',
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_large',
is_supported=True,
),
't5-xl': reference_model_metadata(
large_model_reference='T5_XL',
reference_model_path='gs://t5-data/pretrained_models/t5x/flan_t5_xl/',
reward_model_reference='T5_XL',
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xl',
is_supported=True,
),
't5-xxl': reference_model_metadata(
large_model_reference='T5_XXL',
reference_model_path=(
'gs://t5-data/pretrained_models/t5x/flan_t5_xxl/'
),
reward_model_reference='T5_XXL',
reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xxl',
is_supported=True,
),
'palm-tiny': reference_model_metadata(
large_model_reference='PALM_TINY',
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/',
reward_model_reference='PALM_TINY',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/',
is_supported=False,
),
'gecko': reference_model_metadata(
large_model_reference='GECKO',
reference_model_path=(
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko/'
),
reward_model_reference='GECKO',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko_pretrain/',
is_supported=False,
),
'otter': reference_model_metadata(
large_model_reference='OTTER',
reference_model_path=(
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter/'
),
reward_model_reference='OTTER',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
is_supported=False,
),
'bison': reference_model_metadata(
large_model_reference='BISON',
reference_model_path=(
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
),
reward_model_reference='OTTER',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
is_supported=False, # Deprecated: Use text-bison@001 instead.
),
'text-bison@001': reference_model_metadata(
large_model_reference='BISON',
reference_model_path=(
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
),
reward_model_reference='OTTER',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
is_supported=True,
),
'elephant': reference_model_metadata(
large_model_reference='ELEPHANT',
reference_model_path=(
'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
),
reward_model_reference='OTTER',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
is_supported=False,
),
'llama-2-7b': reference_model_metadata(
large_model_reference='LLAMA_2_7B',
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
reward_model_reference='LLAMA_2_7B',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
is_supported=True,
),
'llama-2-13b': reference_model_metadata(
large_model_reference='LLAMA_2_13B',
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
reward_model_reference='LLAMA_2_13B',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
is_supported=True,
),
'llama-2-7b-chat': reference_model_metadata(
large_model_reference='LLAMA_2_7B_CHAT',
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b_chat/',
reward_model_reference='LLAMA_2_7B_CHAT',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b_chat/',
is_supported=True,
),
'llama-2-13b-chat': reference_model_metadata(
large_model_reference='LLAMA_2_13B_CHAT',
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
reward_model_reference='LLAMA_2_13B_CHAT',
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
is_supported=True,
),
}
reference_model_key = large_model_reference.lower().replace('_', '-')
if reference_model_key not in reference_models:
supported_models = [
k for k, v in reference_models.items() if v.is_supported
]
raise ValueError(
f'Unknown reference model {large_model_reference}.'
' large_model_reference must be one of'
f' {sorted(supported_models)}.'
)
reference_model = reference_models[reference_model_key]
# TODO(latture): Move this logic to a container component and use
# PredefinedModels enum to resolve model paths.
outputs = NamedTuple(
- 'BaseModelMetadata',
+ 'Outputs',
large_model_reference=str,
reference_model_path=str,
reward_model_reference=str,
reward_model_path=str,
)
- reference_model_key = large_model_reference.upper().replace('-', '_')
- predefined_model_paths = {
- 'PALM_TINY': (
- 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny/'
- ),
- 'GECKO': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko/',
- 'OTTER': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter/',
- 'BISON': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/',
- 'ELEPHANT': (
- 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
- ),
- 'T5_SMALL': 'gs://t5-data/pretrained_models/t5x/flan_t5_small/',
- 'T5_LARGE': 'gs://t5-data/pretrained_models/t5x/flan_t5_large/',
- 'T5_XL': 'gs://t5-data/pretrained_models/t5x/flan_t5_xl/',
- 'T5_XXL': 'gs://t5-data/pretrained_models/t5x/flan_t5_xxl/',
- }
- predefined_reward_model_paths = {
- 'PALM_TINY': (
- 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_palm_tiny'
- ),
- 'GECKO': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_gecko_pretrain',
- 'OTTER': 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain',
- 'ELEPHANT': (
- 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_elephant/'
- ),
- 'T5_SMALL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_small',
- 'T5_LARGE': 'gs://t5-data/pretrained_models/t5x/t5_1_1_large',
- 'T5_XL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_xl',
- 'T5_XXL': 'gs://t5-data/pretrained_models/t5x/t5_1_1_xxl',
- }
- if reference_model_key not in predefined_model_paths:
- raise ValueError(
- f'No metadata found for `{reference_model_key}`. '
- f'Base model must be one of {list(predefined_model_paths.keys())}.'
- )
- # Mapping from base model to its corresponding reward model.
- reference_model_to_reward_model = {
- 'PALM_TINY': 'PALM_TINY',
- 'GECKO': 'GECKO',
- 'OTTER': 'OTTER',
- 'BISON': 'OTTER',
- 'ELEPHANT': 'ELEPHANT',
- 'T5_SMALL': 'T5_SMALL',
- 'T5_LARGE': 'T5_LARGE',
- 'T5_XL': 'T5_XL',
- 'T5_XXL': 'T5_XXL',
- }
- reward_model_key = reference_model_to_reward_model[reference_model_key]
return outputs(
- large_model_reference=reference_model_key,
+ large_model_reference=reference_model.large_model_reference,
reference_model_path=(
- reference_model_path or predefined_model_paths[reference_model_key]
+ reference_model_path or reference_model.reference_model_path
),
- reward_model_reference=reward_model_key,
- reward_model_path=predefined_reward_model_paths[reward_model_key],
+ reward_model_reference=reference_model.reward_model_reference,
+ reward_model_path=reference_model.reward_model_path,
)
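As a quick check of the new resolution behavior, a hedged sketch (same .python_func assumption as above; the expected values come from the reference_models table in this diff):

resolve = resolve_reference_model_metadata.python_func

# Legacy enum-style names are normalized: 'T5_SMALL' -> 't5-small'.
meta = resolve(large_model_reference='T5_SMALL')
assert meta.large_model_reference == 'T5_SMALL'
assert meta.reward_model_path == 'gs://t5-data/pretrained_models/t5x/t5_1_1_small'

# text-bison@001 pairs the BISON base model with the OTTER reward model.
meta = resolve(large_model_reference='text-bison@001')
assert meta.reward_model_reference == 'OTTER'

# Unknown names raise a ValueError that lists only is_supported entries.
resolve(large_model_reference='unicorn')  # raises ValueError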


@@ -46,9 +46,9 @@ def infer_pipeline(
Args:
large_model_reference: Name of the base model. Supported values are
- ``BISON``, ``T5_SMALL``, ``T5_LARGE``, ``T5_XL``, and ``T5_XXL``.
- ``BISON`` and ``T5_SMALL`` are supported in ``us-central1` and
- ``europe-west4``. ``T5_LARGE``, ``T5_XL`` and ``T5_XXL`` are only
+ ``text-bison@001``, ``t5-small``, ``t5-large``, ``t5-xl`` and ``t5-xxl``.
+ ``text-bison@001`` and ``t5-small`` are supported in ``us-central1`` and
+ ``europe-west4``. ``t5-large``, ``t5-xl`` and ``t5-xxl`` are only
supported in ``europe-west4``.
model_checkpoint: Cloud storage path to the model checkpoint.
prompt_dataset: Cloud storage path to an unlabeled prompt dataset used for


@@ -68,9 +68,9 @@ def rlhf_pipeline(
the prompt, ``candidate_0`` and ``candidate_1`` that contain candidate
responses, ``choice`` that specifies the preferred candidate.
large_model_reference: Name of the base model. Supported values are
- ``BISON``, ``T5_SMALL``, ``T5_LARGE``, ``T5_XL``, and ``T5_XXL``.
- ``BISON`` and ``T5_SMALL`` are supported in ``us-central1` and
- ``europe-west4``. ``T5_LARGE``, ``T5_XL`` and ``T5_XXL`` are only
+ ``text-bison@001``, ``t5-small``, ``t5-large``, ``t5-xl`` and ``t5-xxl``.
+ ``text-bison@001`` and ``t5-small`` are supported in ``us-central1`` and
+ ``europe-west4``. ``t5-large``, ``t5-xl`` and ``t5-xxl`` are only
supported in ``europe-west4``.
model_display_name: Name of the fine-tuned model shown in the Model
Registry. If not provided, a default name will be created.
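
Callers that previously passed ``BISON`` to these pipelines would now pass ``text-bison@001``. A hedged compile sketch using the KFP v2 compiler; the parameter value is the only part this diff prescribes, the import path is assumed, and baking the parameter in at compile time via pipeline_parameters is one option among several:

from kfp import compiler

# Assumed import path for the pipeline defined in this diff.
from google_cloud_pipeline_components.preview.llm import rlhf_pipeline

compiler.Compiler().compile(
    pipeline_func=rlhf_pipeline,
    package_path='rlhf_pipeline.yaml',
    pipeline_parameters={'large_model_reference': 'text-bison@001'},
)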