# SPDX-License-Identifier: Apache-2.0

import random
from collections.abc import Callable
from typing import NamedTuple, Optional, Union

import numpy as np
import pytest
import torch

from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
                                   create_penalty_tensor,
                                   create_prompt_tokens_tensor,
                                   fake_apply_logitsprocs,
                                   fake_update_logitsprocs_state)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
                                             LogitBiasLogitsProcessor,
                                             LogitsProcessor,
                                             MinPLogitsProcessor,
                                             MinTokensLogitsProcessor,
                                             MoveDirectionality,
                                             init_builtin_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata

PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(1 if current_platform.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none"

# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]
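
# Sizing note (illustrative arithmetic, not enforced by the constants above):
# with REQS_PER_LOGITPROC = 50 and all four logitproc entries under test
# (logit bias, min-p, min-tokens, and "none"), the largest mixed workload
# generated below is 4 * 50 = 200 requests, which fits within MAX_NUM_REQS.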


class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc.
    """
    workload_index: int
    logitproc_type: LogitprocType  # Enabled logitproc (class) or "none"
    out_tokens: list[int]  # Output tokens required for min tokens test
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # Number of output tokens is randomly 0x, 1x or 2x the min-tokens
        # threshold used in testing. Output token values don't matter
        # *for these tests* so use 0 as a dummy value.
        self.out_tokens = ([0] *
                           (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """For debugging"""
        summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
        return f"LogitsProcsRequestParams({summ})"


def _generate_fake_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Generate fake sampling metadata with fake logitsprocs"""
    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(0,
                              vocab_size,
                              size=np.random.randint(
                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
    logitsprocs = init_builtin_logitsprocs(
        pin_memory_available=PIN_MEMORY_AVAILABLE,
        max_num_reqs=MAX_NUM_REQS + 1,
        device=device)

    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
                                                     vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=logitsprocs)
    return fake_sampling_metadata
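
# Note on the fake metadata above: penalties are set to their no-op values
# (0.0 frequency/presence, 1.0 repetition, no_penalties=True) and sampling is
# all-greedy, so any change to the logits observed in these tests should come
# from the logits processors alone.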


def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
    """Generate fake logits and sampling metadata"""
    fake_logits = create_fake_logits(batch_size, VOCAB_SIZE)
    # Create one dominant token per batch, to support min-p test
    for i in range(batch_size):
        fake_logits[i, 0] = 10.0  # High logit for first token
        fake_logits[i, 1:] = 1e-2  # Others remain low
    sampling_metadata = _generate_fake_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    return LogitsprocsTestFakes(
        logits=fake_logits,
        sampling_metadata=sampling_metadata,
    )
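
# Illustrative min-p arithmetic for the fake logits above (assumes the
# standard min-p rule of masking tokens whose probability falls below
# min_p * max probability): token 0 has logit 10.0 and every other token has
# logit 0.01, so after softmax token 0 carries nearly all probability mass and
# the remaining tokens fall far below 0.1 of the peak, which is why
# _min_p_validate expects them to be masked to -inf when min_p=0.1.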


def _sampling_params_from_logitproc(
        logitproc_type: LogitprocType) -> SamplingParams:
    """Customize request SamplingParams for a specified logitproc"""
    # SamplingParams for req with no logitproc
    kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
    if fxn := logitsprocs_test_mapping[logitproc_type].gen_request_fxn:
        fxn(kwargs)
    return SamplingParams(**kwargs)
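
# For example (illustrative): for a MinPLogitsProcessor request the mapping's
# gen_request_fxn is _min_p_params, which overrides the baseline kwargs with
# min_p=0.1; for STR_NO_LOGITPROC there is no gen_request_fxn and the no-op
# baseline kwargs are used unchanged.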


def _generate_mixed_logitsprocs_batch_params(
    reqs_per_logitproc: int,
    logitsprocs_types: list[LogitprocType],
) -> list[LogitsProcsRequestParams]:
    """Define key params for a batch of requests with a different
    logitproc enabled per request.

    The batch will have `reqs_per_logitproc` repeats for all
    `logitsprocs_types` under test, including the case where
    no logitsproc is enabled. The batch is randomly shuffled. The
    size of the batch is `reqs_per_logitproc` times
    `n = len(logitsprocs_types)`

    Args:
      reqs_per_logitproc: number of requests using each logitproc
      logitsprocs_types: logitsprocs under test

    Returns:
      List of per-request params which configure the engine for that request's
      enabled logitproc
    """
    batch_size = len(logitsprocs_types) * reqs_per_logitproc
    # Generate multiple repeats of key params for each logitproc;
    # apply a random inverse permutation to the iteration
    # over logitsprocs, such that logitsprocs are shuffled.
    batch_perm = random.sample(range(batch_size), k=batch_size)
    return [
        LogitsProcsRequestParams(
            workload_index=idx,
            logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc])
        for idx, pdx in enumerate(batch_perm)
    ]
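
# Illustrative example of the shuffling above: with reqs_per_logitproc=2 and
# logitsprocs_types=[LogitBiasLogitsProcessor, STR_NO_LOGITPROC], batch_size
# is 4 and batch_perm might be [2, 0, 3, 1]; permuted indices 0-1 map to the
# first type and 2-3 to the second, so the returned params use (in workload
# order) ["none", logit-bias, "none", logit-bias].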


def _raise_error_invalid(
    msg_suffix: str,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
    err_cls: type[Exception] = ValueError,
) -> None:
    raise err_cls(f"Validation failed for step={step_idx}, "
                  f"batch_index={batch_index}, "
                  f"workload_index={request_params.workload_index}, "
                  f"req_params={request_params}. Reason: {msg_suffix}")


def _logit_bias_params(kwargs: dict) -> None:
    """Logit bias config"""
    kwargs["logit_bias"] = {
        random.randint(0, VOCAB_SIZE - 1): random.choice([-0.1, 0.2])
    }
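
# The bias is purely additive, so for the biased token the validator below
# expects new_logit == old_logit + bias; e.g. (illustrative) a non-dominant
# token with logit 0.01 and bias 0.2 should come out as approximately 0.21,
# while every unbiased token must be left untouched.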


def _logit_bias_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate logit bias logitproc applied correctly"""
    logit_bias = request_params.params.logit_bias
    logits_old = (
        test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
    logits_new = logits_new[batch_index].cpu()
    for token_id in range(VOCAB_SIZE):
        logit_old_value = logits_old[token_id]
        logit_new_value = logits_new[token_id]
        if token_id in logit_bias:
            bias_value = logit_bias[token_id]
            exp_value = bias_value + logit_old_value
            if logit_new_value != pytest.approx(exp_value):
                _raise_error_invalid(msg_suffix=(
                    f"Biased token {token_id} logit value {logit_new_value} "
                    f"does not match expected value {exp_value} "
                    f"given bias {bias_value}"),
                                     batch_index=batch_index,
                                     request_params=request_params,
                                     step_idx=step_idx)
        else:
            if logit_new_value != pytest.approx(logit_old_value):
                _raise_error_invalid(msg_suffix=(
                    f"Unbiased token {token_id} logit value {logit_new_value} "
                    f"does not match expected value {logit_old_value}"),
                                     batch_index=batch_index,
                                     request_params=request_params,
                                     step_idx=step_idx)


def _min_p_params(kwargs: dict) -> None:
    """Min-p logitproc config"""
    kwargs["min_p"] = 0.1


def _min_p_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-p logitproc applied correctly"""
    for token_id in range(VOCAB_SIZE):
        logits_for_token = logits_new[batch_index][token_id]
        if token_id == 0:
            # Dominant token should always be unmasked
            if logits_for_token == -float("inf"):
                _raise_error_invalid(
                    msg_suffix="Invalid: dominant token 0 masked (-inf)",
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
        else:
            if request_params.params.min_p > 0.0:
                # Non-dominant tokens should be masked when min_p > 0
                if logits_for_token != -float("inf"):
                    _raise_error_invalid(
                        msg_suffix=
                        f"Invalid: non-dominant token {token_id} not masked",
                        batch_index=batch_index,
                        request_params=request_params,
                        step_idx=step_idx)
            else:
                # No masking when min_p is 0
                if logits_for_token == -float("inf"):
                    _raise_error_invalid(
                        msg_suffix=
                        f"Invalid: token {token_id} masked when min_p=0.0",
                        batch_index=batch_index,
                        request_params=request_params,
                        step_idx=step_idx)


def _min_tokens_params(kwargs: dict) -> None:
    """Min-tokens logitproc config"""
    kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
    kwargs["stop_token_ids"] = [
        np.random.randint(0, VOCAB_SIZE - 1)
        for _ in range(np.random.randint(0, VOCAB_SIZE))
    ]


def _min_tokens_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-tokens logitproc applied correctly"""
    ref_num_out_tokens = len(request_params.out_tokens)
    min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
    ref_all_stop_token_ids = request_params.params.all_stop_token_ids
    mt_lp: MinTokensLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
    assert isinstance(mt_lp, MinTokensLogitsProcessor)
    min_tok = mt_lp.min_toks.get(batch_index, None)

    # Validate min-tokens logits processor state
    if min_tok:
        (_, out_tok, all_stop_token_ids) = min_tok
        num_out_tokens = len(out_tok)
        if num_out_tokens != ref_num_out_tokens:
            _raise_error_invalid(msg_suffix=(
                "Number of output tokens in min-tokens logits processor "
                f"request metadata ({num_out_tokens}) does not match "
                f"reference ({ref_num_out_tokens})."),
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx)
        if ref_all_stop_token_ids != all_stop_token_ids:
            _raise_error_invalid(msg_suffix=(
                "Stop token ids do not match reference; all_stop_token_ids: "
                f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
                f"{sorted(ref_all_stop_token_ids)}"),
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx)
        if min_reached:
            _raise_error_invalid(msg_suffix=(
                "Min-tokens request has reached its minimum length, but the "
                "batch index is still tracked by the min-tokens logits "
                "processor."),
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx,
                                 err_cls=RuntimeError)
    elif not min_reached:
        _raise_error_invalid(msg_suffix=(
            "Min-tokens request has not reached its minimum length, but the "
            "batch index is not tracked by the min-tokens logits processor."),
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx,
                             err_cls=RuntimeError)

    # Validate min-tokens logits
    for token_id in range(VOCAB_SIZE):
        logits_for_token = logits_new[batch_index][token_id]
        if token_id in ref_all_stop_token_ids and not min_reached:
            if logits_for_token != -float("inf"):
                _raise_error_invalid(
                    msg_suffix=(f"Token {token_id} is a stop token and "
                                "the sequence has not reached min length, "
                                "but the token is not masked "
                                f"(logit={logits_for_token})"),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
        else:
            if logits_for_token == -float("inf"):
                _raise_error_invalid(
                    msg_suffix=(f"Token {token_id} should not be masked but "
                                f"is (output len={ref_num_out_tokens})"),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
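
# Summary of the min-tokens contract checked above: a request is tracked by
# the min-tokens logits processor only while its output is shorter than
# min_tokens, and while tracked every configured stop token must be masked to
# -inf; once the minimum length is reached the request is dropped from
# tracking and no token may remain masked.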


def _none_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate that no logits processors are applied"""
    logits = (
        test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
    ref_logits = logits_new[batch_index]
    if not torch.all(ref_logits == logits):
        mismatch_toks = (ref_logits
                         != logits).nonzero(as_tuple=True)[0].tolist()
        mismatch_strs = []
        for token in mismatch_toks:
            val = float(logits[token])
            ref_val = float(ref_logits[token])
            mismatch_strs.append(f"({token=},{val=},{ref_val=})")
        _raise_error_invalid(msg_suffix=(
            f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx)


class LogitsprocTestHelpers(NamedTuple):
    """Supports setting up and validating logitsprocs unit tests."""
    eval_fxn: Callable
    gen_request_fxn: Optional[Callable] = None


logitsprocs_test_mapping = {
    STR_NO_LOGITPROC:
    LogitsprocTestHelpers(eval_fxn=_none_validate),
    LogitBiasLogitsProcessor:
    LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
                          eval_fxn=_logit_bias_validate),
    MinPLogitsProcessor:
    LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
                          eval_fxn=_min_p_validate),
    MinTokensLogitsProcessor:
    LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
                          eval_fxn=_min_tokens_validate),
}


def _get_test_cases() -> list[list[LogitprocType]]:
    """Each test case is a set of logitsprocs"""
    logitsprocs_types = list(logitsprocs_test_mapping.keys())
    return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
                                   for logitproc_type in logitsprocs_types
                                   if logitproc_type != STR_NO_LOGITPROC
                                   ] + [logitsprocs_types]
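
# With the builtin mapping above, _get_test_cases() yields five cases:
# ["none"] alone, each of [LogitBiasLogitsProcessor, "none"],
# [MinPLogitsProcessor, "none"], [MinTokensLogitsProcessor, "none"], and
# finally all four types together in one mixed batch.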


def _generate_fake_step_update(
    persistent_batch: list[LogitsProcsRequestParams],
    workload_params: list[LogitsProcsRequestParams],
    wdx: int,
    batch_update_builder: BatchUpdateBuilder,
) -> tuple[Optional[BatchUpdate], int, int]:
    batch_size = len(persistent_batch)
    workload_size = len(workload_params)
    workload_reqs_remaining = workload_size - wdx
    max_add_remove_per_step = max(1, int(0.2 * workload_size))

    # 50% of steps: add no reqs
    # Other 50%: add a limited number of reqs (at most the number of
    # workload reqs remaining, and at most an arbitrary per-step max)
    # If no workload reqs remain: 100% of steps have 0 adds
    num_step_add = random.choice([
        0,
        random.randint(1, min(max_add_remove_per_step,
                              workload_reqs_remaining))
    ]) if workload_reqs_remaining else 0

    # 50% of steps: remove no requests
    # Other 50%: remove a limited number of reqs (at most the number of
    # persistent batch reqs remaining, and at most an arbitrary per-step max)
    # If persistent batch is empty: 100% of steps have 0 removals until
    # more requests are added. Assume that removed requests are always
    # drawn from the current batch, before new adds
    num_step_remove = random.choice([
        0, random.randint(1, min(max_add_remove_per_step, batch_size))
    ]) if batch_size else 0

    num_step_add_replace = min(num_step_add, num_step_remove)

    # Generate fake removed request indices drawn from persistent batch indices
    for removal in random.sample(range(batch_size), num_step_remove):
        batch_update_builder.removed_append(removal)

    # Get added requests from workload
    for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
        # Replace as many removed requests as possible with added requests
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
            (add_remove_idx, add_req_params.params, add_req_params.out_tokens))
        persistent_batch[add_remove_idx] = add_req_params

    # Append remaining added requests to end of batch
    add_reqs_append = workload_params[(wdx +
                                       num_step_add_replace):(wdx +
                                                              num_step_add)]
    batch_update_builder.added.extend([
        (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
    ])
    persistent_batch.extend(add_reqs_append)
    pre_condense_batch_size = len(persistent_batch)
    wdx += num_step_add  # Update workload offset

    # Simulate condensing the persistent batch
    last_nonempty_index = pre_condense_batch_size - 1
    condensed_to_idxs = set()
    while batch_update_builder.removed:
        if (last_nonempty_index in batch_update_builder.removed
                or last_nonempty_index in condensed_to_idxs):
            last_nonempty_index -= 1
            continue
        # last_nonempty_index is the highest persistent batch index that was
        # not removed
        first_empty_index = batch_update_builder.peek_removed()
        assert first_empty_index is not None
        if first_empty_index > last_nonempty_index:
            break
        # first_empty_index is the lowest removed persistent batch index
        # that is less than last_nonempty_index
        #
        # move last_nonempty_index -> first_empty_index
        batch_update_builder.pop_removed()
        condensed_to_idxs.add(first_empty_index)
        persistent_batch[first_empty_index] = persistent_batch[
            last_nonempty_index]
        batch_update_builder.moved.append(
            (last_nonempty_index, first_empty_index,
             MoveDirectionality.UNIDIRECTIONAL))

        last_nonempty_index -= 1

    # Now removed requests & gaps left by non-removed requests that got
    # moved downward are grouped consecutively in the upper indices of
    # the persistent batch. Truncate them to get the condensed persistent batch
    condensed_batch_size = batch_size + num_step_add - num_step_remove
    persistent_batch[:] = persistent_batch[0:condensed_batch_size]

    if condensed_batch_size > 1:
        # Simulate an arbitrary reorder_batch() in the kernel backend:
        # generate a random number k of non-overlapping swap tuples
        k = random.randint(0, condensed_batch_size // 2)
        idxs = list(range(condensed_batch_size))
        random.shuffle(idxs)
        swaps = [
            tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
        ]
        batch_update_builder.moved.extend([
            (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
        ])
        for adx, bdx in swaps:
            persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
                bdx], persistent_batch[adx]

    return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
            workload_size - wdx)
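
# Worked example of the condense step above (illustrative values): suppose the
# pre-condense batch is [A, B, C, D] and index 1 was removed. The loop moves
# the highest surviving request D from index 3 into the gap at index 1
# (recorded as a UNIDIRECTIONAL move 3 -> 1), leaving [A, D, C] after the
# batch is truncated to condensed_batch_size; random SWAP moves may then be
# appended to emulate kernel-backend reordering.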


def _assert_valid(
    batch_size: int,
    persistent_batch: list[LogitsProcsRequestParams],
    test_fakes: LogitsprocsTestFakes,
    slice_idxs: list[int],
    logits_w_lp: torch.Tensor,
    step_idx: int,
) -> None:
    if not slice_idxs:
        # Trivial case of empty persistent batch
        assert len(persistent_batch) == 0
        if logits_w_lp.shape[0] != 0:
            raise ValueError("Fake persistent batch is empty but logitsprocs "
                             f"output batch has shape {logits_w_lp.shape}")
        return

    # Validate logits for each fake request
    for batch_index in range(batch_size):
        request_params = persistent_batch[batch_index]
        # Invoke the appropriate validation function for
        # the logitproc employed by this request
        fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
        fxn(test_fakes=test_fakes,
            persistent_batch=persistent_batch,
            logits_new=logits_w_lp,
            batch_index=batch_index,
            request_params=request_params,
            step_idx=step_idx)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
                     logitsprocs_under_test: list[LogitprocType]):
    random.seed(40)
    torch.set_default_device(device)

    # Define a shuffled batch of requests which individually use a different
    # logitproc, or no logitproc at all
    workload_params = _generate_mixed_logitsprocs_batch_params(
        reqs_per_logitproc=reqs_per_logitproc,
        logitsprocs_types=logitsprocs_under_test)
    workload_size = len(workload_params)

    # Create fake test data structures for testing.
    test_fakes = _generate_test_fakes(workload_size, device)

    wdx = 0  # Next request index in workload to add
    persistent_batch: list[LogitsProcsRequestParams] = [
    ]  # Persistent batch state, as a list of per-request params

    # Builder for fake per-step batch updates: removed request indices are
    # drawn from the current persistent batch before adds
    batch_update_builder = BatchUpdateBuilder()

    # Break when the entire workload has been added previously and the
    # persistent batch is empty
    workload_reqs_remaining = workload_size
    batch_size = 0
    step_idx = 0
    while True:
        if not (workload_reqs_remaining or batch_size):
            break

        (
            batch_update,
            wdx,
            workload_reqs_remaining,
        ) = _generate_fake_step_update(
            persistent_batch=persistent_batch,
            workload_params=workload_params,
            wdx=wdx,
            batch_update_builder=batch_update_builder,
        )
        batch_size = len(persistent_batch)

        # Apply fake batch update to logitsprocs
        fake_update_logitsprocs_state(test_fakes, batch_update)

        # Emulate application of logits processors in the engine
        slice_idxs = [req.workload_index for req in persistent_batch]
        logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()

        _assert_valid(
            batch_size=batch_size,
            persistent_batch=persistent_batch,
            test_fakes=test_fakes,
            slice_idxs=slice_idxs,
            logits_w_lp=logits_w_lp,
            step_idx=step_idx,
        )

        step_idx += 1