vllm/tests/v1/sample/test_logits_processors.py

627 lines
25 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import random
from collections.abc import Callable
from typing import NamedTuple, Optional, Union
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
# Device 0 only when exactly one device is reported; otherwise devices 0
# and 1. NOTE(review): this also yields two entries when device_count()
# reports 0 devices.
CUDA_DEVICES = [
    f"{current_platform.device_type}:{device_idx}"
    for device_idx in range(2 if current_platform.device_count() != 1 else 1)
]
MAX_NUM_PROMPT_TOKENS = 64
# A request counts as having reached min-tokens once it has at least this
# many output tokens; it is also the min_tokens value used by the test.
MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50
# Sentinel id meaning "no logitproc enabled for this request"
STR_NO_LOGITPROC = "none"
# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]
class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc.
    """
    workload_index: int  # Position of this request within the workload
    logitproc_type: LogitprocType  # Logitproc enabled (class, or "none")
    out_tokens: list[int]  # Output tokens required for min tokens test
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # Number of output tokens is randomly 0, 1x or 2x the min-tokens
        # threshold used in testing, so some requests have reached the
        # threshold and some have not. Output token values don't matter
        # *for these tests* so use 0 as a dummy value
        self.out_tokens = ([0] * (MIN_TOKENS_LEN_THRESHOLD *
                                  random.randint(0, 2)))
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """For debugging: ``ClassName(attr=value, ...)``."""
        summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
        # Use the real class name; this string is embedded in validation
        # error messages, where a placeholder name is misleading.
        return f"{type(self).__name__}({summ})"
def _generate_fake_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Build fake SamplingMetadata carrying the builtin logitsprocs.

    Each of the ``batch_size`` requests gets ``num_output_tokens`` random
    output token ids and a prompt of random length (1 to
    MAX_NUM_PROMPT_TOKENS - 1) of random token ids, all drawn from
    ``[0, vocab_size)``. Sampling is all-greedy with neutral penalties.
    """

    def _random_ids(count: int) -> list[int]:
        # Uniform token ids in [0, vocab_size) as plain Python ints
        return np.random.randint(0, vocab_size, size=count).tolist()

    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(_random_ids(num_output_tokens))
        prompt_len = np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
        prompt_token_ids.append(_random_ids(prompt_len))

    builtin_logitsprocs = init_builtin_logitsprocs(
        pin_memory_available=PIN_MEMORY_AVAILABLE,
        max_num_reqs=MAX_NUM_REQS + 1,
        device=device)

    return SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
                                                     vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=builtin_logitsprocs)
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
    """Create fake logits plus matching fake sampling metadata.

    In each of the first ``batch_size`` logits rows, token 0 is made
    dominant (logit 10.0) and every other token gets a small logit (1e-2);
    the min-p test relies on this layout.
    """
    logits = create_fake_logits(batch_size, VOCAB_SIZE)
    # One dominant token per request row, all others uniformly low
    logits[:batch_size, 0] = 10.0
    logits[:batch_size, 1:] = 1e-2
    metadata = _generate_fake_sampling_metadata(NUM_OUTPUT_TOKENS,
                                                batch_size, VOCAB_SIZE,
                                                torch.device(device))
    return LogitsprocsTestFakes(logits=logits, sampling_metadata=metadata)
def _sampling_params_from_logitproc(
        logitproc_type: LogitprocType) -> SamplingParams:
    """Build the SamplingParams for a request using ``logitproc_type``.

    Starts from settings that enable no logitproc (min_p=0.0, no logit
    bias, min_tokens=0), then lets the logitproc's registered
    ``gen_request_fxn`` (if any) mutate those kwargs in place.
    """
    kwargs = dict(min_p=0.0, logit_bias=None, min_tokens=0)
    gen_request_fxn = logitsprocs_test_mapping[logitproc_type].gen_request_fxn
    if gen_request_fxn:
        gen_request_fxn(kwargs)
    return SamplingParams(**kwargs)
def _generate_mixed_logitsprocs_batch_params(
    reqs_per_logitproc: int,
    logitsprocs_types: list[str],
) -> list[LogitsProcsRequestParams]:
    """Define key params for a randomly shuffled workload of requests.

    Each entry of `logitsprocs_types` (a logitproc class, or
    STR_NO_LOGITPROC meaning no logitproc) is assigned to exactly
    `reqs_per_logitproc` requests, so the workload holds
    `len(logitsprocs_types) * reqs_per_logitproc` requests in random order.

    Args:
      reqs_per_logitproc: number of requests using each logitproc
      logitsprocs_types: logitsprocs under test

    Returns:
      Per-request params; element i has ``workload_index == i`` and
      SamplingParams configured for that request's logitproc.
    """
    num_reqs = len(logitsprocs_types) * reqs_per_logitproc
    # Random permutation of slots; `slot // reqs_per_logitproc` selects the
    # logitproc, so each type lands on reqs_per_logitproc random positions.
    slots = random.sample(range(num_reqs), k=num_reqs)
    workload: list[LogitsProcsRequestParams] = []
    for workload_index, slot in enumerate(slots):
        lp_type = logitsprocs_types[slot // reqs_per_logitproc]
        workload.append(
            LogitsProcsRequestParams(workload_index=workload_index,
                                     logitproc_type=lp_type))
    return workload
def _raise_error_invalid(
    msg_suffix: str,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
    err_cls: type[Exception] = ValueError,
) -> None:
    """Raise ``err_cls`` describing a failed per-request validation.

    The message names the step, persistent-batch index, workload index and
    request params, then appends ``msg_suffix`` as the reason.
    """
    location = (f"step={step_idx}, batch_index={batch_index}, "
                f"workload_index={request_params.workload_index}")
    message = (f"Validation failed for {location}, "
               f"req_params={request_params}. Reason: {msg_suffix}")
    raise err_cls(message)
def _logit_bias_params(kwargs: dict) -> None:
    """Logit bias config: bias one random vocab token by -0.1 or +0.2."""
    biased_token = random.randint(0, VOCAB_SIZE - 1)
    bias = random.choice([-0.1, 0.2])
    kwargs["logit_bias"] = {biased_token: bias}
def _logit_bias_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate logit bias logitproc applied correctly.

    Each biased token's new logit must equal its original fake logit plus
    the configured bias; every unbiased token's logit must be unchanged.
    Both are compared with ``pytest.approx`` default tolerances.
    """
    logit_bias = request_params.params.logit_bias
    # Pull both rows back as plain Python floats once: pytest.approx is
    # documented for numbers / numpy arrays rather than torch tensors, and
    # per-element tensor indexing is slow. Error messages then also show
    # numbers instead of tensor reprs.
    logits_old = test_fakes.logits[
        persistent_batch[batch_index].workload_index].cpu().tolist()
    row_new = logits_new[batch_index].cpu().tolist()
    for token_id in range(VOCAB_SIZE):
        logit_old_value = logits_old[token_id]
        logit_new_value = row_new[token_id]
        if token_id in logit_bias:
            bias_value = logit_bias[token_id]
            exp_value = bias_value + logit_old_value
            if logit_new_value != pytest.approx(exp_value):
                _raise_error_invalid(msg_suffix=(
                    f"Biased token {token_id} logit value {logit_new_value} "
                    f"does not match expected value {exp_value} "
                    f"given bias {bias_value}"),
                                     batch_index=batch_index,
                                     request_params=request_params,
                                     step_idx=step_idx)
        elif logit_new_value != pytest.approx(logit_old_value):
            _raise_error_invalid(msg_suffix=(
                f"Unbiased token {token_id} logit value {logit_new_value} "
                f"does not match expected value {logit_old_value}"),
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx)
def _min_p_params(kwargs: dict) -> None:
"""Min-p logitproc config"""
kwargs["min_p"] = 0.1
def _min_p_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-p logitproc applied correctly.

    Token 0 (dominant in the fake logits) must never be masked to -inf.
    Every other token must be masked when the request's ``min_p`` is
    positive, and must not be masked otherwise.
    """
    neg_inf = -float("inf")
    row = logits_new[batch_index]

    def _fail(msg: str) -> None:
        _raise_error_invalid(msg_suffix=msg,
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx)

    for tok in range(VOCAB_SIZE):
        masked = row[tok] == neg_inf
        if tok == 0:
            # Dominant token should always be unmasked
            if masked:
                _fail("Invalid: dominant token 0 masked (-inf)")
        elif request_params.params.min_p > 0.0:
            # Non-dominant tokens should be masked when min_p > 0
            if not masked:
                _fail(f"Invalid: non-dominant token {tok} not masked")
        elif masked:
            # No masking when min_p is 0
            _fail(f"Invalid: token {tok} masked when min_p=0.0")
def _min_tokens_params(kwargs: dict) -> None:
    """Min-tokens logitproc config.

    Sets ``min_tokens`` to MIN_TOKENS_LEN_THRESHOLD and chooses a random
    number (0 to VOCAB_SIZE - 1) of random stop token ids, each drawn
    uniformly from the whole vocabulary ``[0, VOCAB_SIZE)``.
    """
    kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
    num_stop_tokens = np.random.randint(0, VOCAB_SIZE)
    # np.random.randint's upper bound is exclusive, so VOCAB_SIZE (not
    # VOCAB_SIZE - 1) is needed for the last token id to be eligible as a
    # stop token. .tolist() yields plain Python ints, not numpy scalars.
    kwargs["stop_token_ids"] = np.random.randint(
        0, VOCAB_SIZE, size=num_stop_tokens).tolist()
def _min_tokens_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-tokens logitsproc applied correctly.

    1. Processor state: the MinTokensLogitsProcessor must track
       ``batch_index`` exactly when the request has fewer than
       MIN_TOKENS_LEN_THRESHOLD output tokens, and a tracked entry must
       agree with the request's output-token count and stop token ids.
    2. Logits: while the minimum is not reached, every stop token must be
       masked to -inf; no other token (and no token once the minimum is
       reached) may be masked.
    """
    ref_num_out_tokens = len(request_params.out_tokens)
    min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
    ref_all_stop_token_ids = request_params.params.all_stop_token_ids

    def _fail(msg: str, err_cls: type[Exception] = ValueError) -> None:
        _raise_error_invalid(msg_suffix=msg,
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx,
                             err_cls=err_cls)

    mt_lp: MinTokensLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
    assert isinstance(mt_lp, MinTokensLogitsProcessor)

    # --- Processor state ---
    tracked = mt_lp.min_toks.get(batch_index, None)
    if not tracked:
        if not min_reached:
            _fail(
                "Expected min-tokens request with min not reached, but "
                "batch index is not recognized by min-tokens logits "
                "processor.", RuntimeError)
    else:
        _, out_tok, all_stop_token_ids = tracked
        num_out_tokens = len(out_tok)
        if num_out_tokens != ref_num_out_tokens:
            _fail("Number of output tokens in min-token logit processor "
                  f"request metadata ({num_out_tokens}) does not match "
                  f"reference ({ref_num_out_tokens}).")
        if ref_all_stop_token_ids != all_stop_token_ids:
            _fail("Stop token ids do not match reference; "
                  f"all_stop_token_ids: {sorted(all_stop_token_ids)}, "
                  "ref_all_stop_token_ids: "
                  f"{sorted(ref_all_stop_token_ids)}")
        if min_reached:
            _fail(
                "Expected min-tokens request with min reached, but batch "
                "index is recognized by min-tokens logits processor.",
                RuntimeError)

    # --- Logits ---
    row = logits_new[batch_index]
    neg_inf = -float("inf")
    for token_id in range(VOCAB_SIZE):
        value = row[token_id]
        if token_id in ref_all_stop_token_ids and not min_reached:
            if value != neg_inf:
                _fail(f"Token {token_id} is a stop token and "
                      "the sequence has not reached min length, "
                      "but the token is not masked "
                      f"(logit={value})")
        elif value == neg_inf:
            _fail(f"Token {token_id} should not be masked but "
                  f"is (output len={ref_num_out_tokens})")
def _none_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate that no logits processors are applied.

    The request's row of ``logits_new`` must be element-wise identical to
    its original fake logits. Otherwise every differing token id is
    reported with its original (``val``) and new (``ref_val``) logit.
    """
    workload_idx = persistent_batch[batch_index].workload_index
    orig_logits = test_fakes.logits[workload_idx].cpu()
    new_logits = logits_new[batch_index]
    differs = new_logits != orig_logits
    if not differs.any():
        return
    mismatch_strs = []
    for token in differs.nonzero(as_tuple=True)[0].tolist():
        # Names matter here: the `=` format specifiers print them verbatim
        val = float(orig_logits[token])
        ref_val = float(new_logits[token])
        mismatch_strs.append(f"({token=},{val=},{ref_val=})")
    summary = ','.join(mismatch_strs)
    _raise_error_invalid(
        msg_suffix=f"Unexpected modification of logits: {summary}",
        batch_index=batch_index,
        request_params=request_params,
        step_idx=step_idx)
class LogitsprocTestHelpers(NamedTuple):
    """Per-logitproc hooks used to set up and validate the unit tests.

    Attributes:
      eval_fxn: validates one request's processed logits.
      gen_request_fxn: optionally mutates a SamplingParams kwargs dict to
        enable the logitproc for a request; ``None`` when no setup is
        needed.
    """
    eval_fxn: Callable[..., None]
    gen_request_fxn: Optional[Callable[[dict], None]] = None
# Registry of per-logitproc test hooks, keyed by logitproc class (or by
# STR_NO_LOGITPROC for requests without a logitproc). Positional args are
# (eval_fxn, gen_request_fxn). Insertion order fixes the order in which
# test cases are generated.
logitsprocs_test_mapping = {
    STR_NO_LOGITPROC:
    LogitsprocTestHelpers(_none_validate),
    LogitBiasLogitsProcessor:
    LogitsprocTestHelpers(_logit_bias_validate, _logit_bias_params),
    MinPLogitsProcessor:
    LogitsprocTestHelpers(_min_p_validate, _min_p_params),
    MinTokensLogitsProcessor:
    LogitsprocTestHelpers(_min_tokens_validate, _min_tokens_params),
}
def _get_test_cases() -> list[list[str]]:
    """Build the logitsproc combinations to parametrize the test over.

    In order: the no-logitproc case alone; each real logitproc paired with
    the no-logitproc case; and finally every registered entry (including
    no-logitproc) together.
    """
    all_types = list(logitsprocs_test_mapping.keys())
    cases: list[list] = [[STR_NO_LOGITPROC]]
    cases.extend([lp_type, STR_NO_LOGITPROC] for lp_type in all_types
                 if lp_type != STR_NO_LOGITPROC)
    cases.append(all_types)
    return cases
def _generate_fake_step_update(
    persistent_batch: list[LogitsProcsRequestParams],
    workload_params: list[LogitsProcsRequestParams],
    wdx: int,
    batch_update_builder: BatchUpdateBuilder,
) -> tuple[Optional[BatchUpdate], int, int]:
    """Simulate one engine step's changes to the persistent batch.

    Randomly removes requests from ``persistent_batch`` and adds the next
    requests from ``workload_params`` (starting at offset ``wdx``), first
    filling removed slots and then appending. It then condenses the batch
    by moving the highest surviving entries into the lowest gaps, and
    finally applies random non-overlapping pairwise swaps. Every removal,
    addition and move is recorded in ``batch_update_builder``, and
    ``persistent_batch`` is mutated in place to mirror it.

    Args:
      persistent_batch: current persistent batch; mutated in place.
      workload_params: the full workload; requests are added in order.
      wdx: index of the next workload request to add.
      batch_update_builder: accumulates removed/added/moved entries.

    Returns:
      ``(batch_update, wdx, workload_reqs_remaining)``: the result of
      ``batch_update_builder.get_and_reset(condensed_batch_size)``, the
      advanced workload offset, and the number of workload requests not
      yet added.
    """
    batch_size = len(persistent_batch)
    workload_size = len(workload_params)
    workload_reqs_remaining = workload_size - wdx
    # Cap on adds (and on removes) per step: 20% of the workload, at least 1
    max_add_remove_per_step = max(1, int(0.2 * workload_size))
    # 50% of steps: add no reqs
    # Other 50%: add a limited number of reqs (at most the number
    # of workload reqs remaining, at most an arbitrary max)
    # If no workload reqs remain: 100% of steps have 0 adds
    num_step_add = random.choice([
        0,
        random.randint(1, min(max_add_remove_per_step,
                              workload_reqs_remaining))
    ]) if workload_reqs_remaining else 0
    # 50% of steps: remove no requests
    # Other 50%: remove a limited number of reqs (at most the number of
    # persistent batch reqs remaining, at most an arbitrary max)
    # If persistent batch is empty: 100% of steps have 0 removals until
    # more requests are added. Assume that removed requests are always
    # drawn from the current batch, before new adds
    num_step_remove = random.choice([
        0, random.randint(1, min(max_add_remove_per_step, batch_size))
    ]) if batch_size else 0
    # Number of adds that reuse a removed slot instead of being appended
    num_step_add_replace = min(num_step_add, num_step_remove)
    # Generate fake removed request indices drawn from persistent batch indices
    for removal in random.sample(range(batch_size), num_step_remove):
        batch_update_builder.removed_append(removal)
    # Get added requests from workload
    for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
        # Replace as many removed requests as possible with added requests
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
            (add_remove_idx, add_req_params.params, add_req_params.out_tokens))
        persistent_batch[add_remove_idx] = add_req_params
    # Append remaining added requests to end of batch
    add_reqs_append = workload_params[(wdx +
                                       num_step_add_replace):(wdx +
                                                              num_step_add)]
    batch_update_builder.added.extend([
        (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
    ])
    persistent_batch.extend(add_reqs_append)
    pre_condense_batch_size = len(persistent_batch)
    wdx += num_step_add  # Update workload offset
    # Simulate condensing persistent batch
    last_nonempty_index = pre_condense_batch_size - 1
    # Slots already filled by a move; never treated as a move source
    condensed_to_idxs = set()
    while batch_update_builder.removed:
        if (last_nonempty_index in batch_update_builder.removed
                or last_nonempty_index in condensed_to_idxs):
            last_nonempty_index -= 1
            continue
        # last_nonempty_index is the highest persistent batch index that was
        # not removed
        first_empty_index = batch_update_builder.peek_removed()
        assert first_empty_index is not None
        if first_empty_index > last_nonempty_index:
            # No remaining gap lies below a surviving entry; done condensing
            break
        # first_empty_index is the lowest removed persistent batch index
        # that is less than last_nonempty_index
        #
        # move last_nonempty_index -> first_empty_index
        batch_update_builder.pop_removed()
        condensed_to_idxs.add(first_empty_index)
        persistent_batch[first_empty_index] = persistent_batch[
            last_nonempty_index]
        batch_update_builder.moved.append(
            (last_nonempty_index, first_empty_index,
             MoveDirectionality.UNIDIRECTIONAL))
        last_nonempty_index -= 1
    # Now removed requests & gaps left by non-removed requests that got
    # moved downward are grouped consecutively in the upper indices of
    # the persistent batch. Truncate them to get condensed persistent batch
    condensed_batch_size = batch_size + num_step_add - num_step_remove
    persistent_batch[:] = persistent_batch[0:condensed_batch_size]
    if condensed_batch_size > 1:
        # Simulate arbitrary reorder_batch() in the kernel backend
        # Generate a random number k of non-overlapping swap tuples
        # (consecutive pairs of a shuffled index list never share an index)
        k = random.randint(0, condensed_batch_size // 2)
        idxs = list(range(condensed_batch_size))
        random.shuffle(idxs)
        swaps = [
            tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
        ]
        batch_update_builder.moved.extend([
            (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
        ])
        for adx, bdx in swaps:
            persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
                bdx], persistent_batch[adx]
    return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
            workload_size - wdx)
def _assert_valid(
    batch_size: int,
    persistent_batch: list[LogitsProcsRequestParams],
    test_fakes: LogitsprocsTestFakes,
    slice_idxs: list[int],
    logits_w_lp: torch.Tensor,
    step_idx: int,
) -> None:
    """Validate logitsprocs output for every request in the persistent batch.

    Args:
      batch_size: number of requests in the persistent batch.
      persistent_batch: per-request params, indexed by batch position.
      test_fakes: fake logits / metadata the logitsprocs were applied to.
      slice_idxs: workload indices of the persistent batch's requests.
      logits_w_lp: logitsprocs output, one row per persistent-batch entry.
      step_idx: current simulated step, used in error messages.

    Raises:
      ValueError: if the output's batch dimension does not match the
        persistent batch. The per-logitproc validators additionally raise
        on any incorrect logits.
    """
    if not slice_idxs:
        # Trivial case of empty persistent batch
        assert len(persistent_batch) == 0
        if logits_w_lp.shape[0] != 0:
            raise ValueError("Fake persistent batch is empty but logitsprocs "
                             f"output batch has shape {logits_w_lp.shape}")
        return
    # Each persistent-batch entry must map to exactly one output row;
    # otherwise the per-request checks below would validate the wrong rows
    # (or silently ignore extra ones).
    if logits_w_lp.shape[0] != batch_size:
        raise ValueError(f"Fake persistent batch has {batch_size} requests "
                         "but logitsprocs output batch has shape "
                         f"{logits_w_lp.shape}")
    # Validate logits for each fake request
    for batch_index in range(batch_size):
        request_params = persistent_batch[batch_index]
        # Invoke the appropriate validation function for
        # the logitproc employed by this request
        fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
        fxn(test_fakes=test_fakes,
            persistent_batch=persistent_batch,
            logits_new=logits_w_lp,
            batch_index=batch_index,
            request_params=request_params,
            step_idx=step_idx)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
                     logitsprocs_under_test: list[str]):
    """Exercise builtin logitsprocs over a simulated, evolving batch.

    A shuffled workload mixing requests for each logitproc under test is
    streamed through a fake persistent batch. Each step randomly removes,
    adds, condenses and reorders requests, applies the resulting batch
    update to the logitsprocs, runs them, and validates every request's
    output logits. The loop ends once the whole workload has been added
    and the persistent batch is empty.
    """
    # Seed both RNGs: workload generation draws from `random` *and* from
    # `np.random` (fake prompt/output tokens, min-tokens stop token ids),
    # so seeding only `random` leaves the test non-reproducible.
    random.seed(40)
    np.random.seed(40)
    # Scope the default device to this test instead of leaking it into the
    # rest of the session via process-global torch.set_default_device().
    with torch.device(device):
        # Define a shuffled batch of requests which individually use a
        # different logitproc, or no logitproc at all
        workload_params = _generate_mixed_logitsprocs_batch_params(
            reqs_per_logitproc=reqs_per_logitproc,
            logitsprocs_types=logitsprocs_under_test)
        workload_size = len(workload_params)
        # Create fake test data structures for testing.
        test_fakes = _generate_test_fakes(workload_size, device)
        wdx = 0  # Next request index in workload to add
        # Persistent batch state, as a list of per-request params
        persistent_batch: list[LogitsProcsRequestParams] = []
        # Collects each step's removed / added / moved requests
        batch_update_builder = BatchUpdateBuilder()
        workload_reqs_remaining = workload_size
        batch_size = 0
        step_idx = 0
        # Stop once the entire workload has been added and the persistent
        # batch is empty
        while workload_reqs_remaining or batch_size:
            (
                batch_update,
                wdx,
                workload_reqs_remaining,
            ) = _generate_fake_step_update(
                persistent_batch=persistent_batch,
                workload_params=workload_params,
                wdx=wdx,
                batch_update_builder=batch_update_builder,
            )
            batch_size = len(persistent_batch)
            # Apply fake batch update to logitsprocs
            fake_update_logitsprocs_state(test_fakes, batch_update)
            # Emulate application of logits processors in engine
            slice_idxs = [req.workload_index for req in persistent_batch]
            logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()
            _assert_valid(
                batch_size=batch_size,
                persistent_batch=persistent_batch,
                test_fakes=test_fakes,
                slice_idxs=slice_idxs,
                logits_w_lp=logits_w_lp,
                step_idx=step_idx,
            )
            step_idx += 1