# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import numpy as np
import pytest
import torch

from vllm.platforms import current_platform
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

# Size of the fake vocabulary used by every test in this module.
VOCAB_SIZE = 1024
# Number of already-generated tokens each fake request starts with.
NUM_OUTPUT_TOKENS = 20
# One device string when the platform reports exactly one device,
# otherwise two ("<device_type>:0" and "<device_type>:1").
CUDA_DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(1 if current_platform.device_count() == 1 else 2)
]
# Exclusive upper bound on the random prompt length of a fake request.
MAX_NUM_PROMPT_TOKENS = 64


def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    """Return a `(batch_size, vocab_size)` float tensor filled with 1e-2.

    No device is passed, so the tensor is created on torch's current
    default device.
    """
    return torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)


def _create_penalty_tensor(batch_size: int, penalty_value: float,
                           device: torch.device) -> torch.Tensor:
    """Return a `(batch_size,)` float tensor filled with `penalty_value`."""
    return torch.full((batch_size, ),
                      fill_value=penalty_value,
                      dtype=torch.float,
                      device=device)


def _create_prompt_tokens_tensor(
    prompt_token_ids: list[list[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Convert ragged prompt token lists into a single int64 tensor.

    `vocab_size` is passed as the pad value to `make_tensor_with_pad`
    (presumably so that padding can never collide with a real token id,
    whose range is `[0, vocab_size)`).
    """
    return make_tensor_with_pad(
        prompt_token_ids,
        pad=vocab_size,
        device=device,
        dtype=torch.int64,
        pin_memory=False,
    )


def _create_logit_bias(
    batch_size: int,
    vocab_size: int,
    bias_value: float,
) -> list[Optional[dict[int, float]]]:
    """Build one logit-bias dict per request.

    Request `i` biases exactly one token: `min(i, vocab_size - 1)`, by
    `bias_value`.
    """
    res: list[Optional[dict[int, float]]] = [{
        min(i, vocab_size - 1): bias_value
    } for i in range(batch_size)]
    return res


def _create_allowed_token_ids(
    batch_size: int,
    vocab_size: int,
    num_allowed_token_ids: int,
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Build the boolean mask used by `test_sampler_allowed_token_ids`.

    Only even-indexed requests get entries. For such a request `i` the
    half-open token range `[min(i, V - 1), min(i + n, V - 1))` is set to
    True, where `V` is `vocab_size` and `n` is `num_allowed_token_ids`.
    Note that the test expects exactly the True entries to end up as -inf.

    Returns:
        The `(batch_size, vocab_size)` bool mask, or None when
        `batch_size` is not positive (no request to mask).
    """
    if batch_size <= 0:
        return None
    mask = torch.zeros((batch_size, vocab_size),
                       dtype=torch.bool,
                       device=device)
    # Odd-indexed rows are intentionally left all-False.
    for i in range(0, batch_size, 2):
        start = min(i, vocab_size - 1)
        end = min(i + num_allowed_token_ids, vocab_size - 1)
        mask[i, start:end] = True
    return mask


def _create_bad_words_token_ids(
        batch_size: int, vocab_size: int,
        bad_words_lengths: tuple[int, ...]) -> dict[int, list[list[int]]]:
    """Generate random bad words for every request in the batch.

    Args:
        batch_size: Number of requests.
        vocab_size: Token ids are drawn from `[0, vocab_size)`.
        bad_words_lengths: Length (in tokens) of each bad word to create
            per request.

    Returns:
        Mapping from batch index to that request's list of bad words.
        When `batch_size >= 2`, one randomly chosen request is removed so
        that the "no bad words" path is covered as well.
    """
    bad_words_token_ids = {}
    for batch_idx in range(batch_size):
        token_ids_single_batch = []
        for bad_words_length in bad_words_lengths:
            token_ids = np.random.choice(vocab_size,
                                         size=bad_words_length,
                                         replace=True).tolist()
            token_ids_single_batch.append(token_ids)
        bad_words_token_ids[batch_idx] = token_ids_single_batch
    if batch_size >= 2:
        # Test no bad_words for some batch
        no_bad_words_batch_idx = np.random.choice(batch_size)
        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
    return bad_words_token_ids


def _bad_word_prefix_matches(bad_word_token_ids: list[int],
                             output_token_ids: list[int]) -> bool:
    """Return True if the output ends with the bad word's prefix.

    The prefix is every token of the bad word except the last one. A
    single-token bad word has an empty prefix and therefore always
    matches.
    """
    prefix = bad_word_token_ids[:-1]
    if not prefix:
        return True
    # If the output is shorter than the prefix the slice is the whole
    # output, the lengths differ and the comparison is False.
    return output_token_ids[-len(prefix):] == prefix


def _update_output_token_ids_for_bad_words(
        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
    """Randomly plant bad-word prefixes and report the expected bans.

    For every request with bad words, the tail of its `output_token_ids`
    (in `metadata`) is mutated in place: each multi-token bad word is
    either planted (its prefix is written to the end of the output, at
    most once per request) or the last output token is changed so that
    this word's prefix does not match.

    Returns:
        Mapping from batch index to the last tokens of all bad words whose
        prefix matches the final output, i.e. the tokens the test expects
        to be set to -inf.
    """
    bad_words_last_tokens = {}
    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
        output_token_ids = metadata.output_token_ids[batch_idx]
        for bad_word_token_ids in bad_words_token_ids:
            if len(bad_word_token_ids) == 1:
                # Single token id always affects logits; nothing to plant.
                continue
            prefix_length = len(bad_word_token_ids) - 1
            has_bad_words = np.random.choice([True, False])
            if has_bad_words:
                output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
                break  # Maximum one update to output_token_ids
            # Make sure this bad word does not match by accident.
            output_token_ids[-1] = (bad_word_token_ids[-2] + 1) % vocab_size

        # Derive the expectation from the *final* output instead of
        # recording it while mutating: a later write to the output tail
        # can (rarely) complete or destroy the prefix of another bad word,
        # and bad words after the `break` above must still be considered.
        bad_words_last_tokens[batch_idx] = [
            bad_word_token_ids[-1] for bad_word_token_ids in bad_words_token_ids
            if _bad_word_prefix_matches(bad_word_token_ids, output_token_ids)
        ]
    return bad_words_last_tokens


def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Create greedy `SamplingMetadata` with every optional feature off.

    Each request gets `num_output_tokens` random output tokens and a
    random prompt of 1 to `MAX_NUM_PROMPT_TOKENS - 1` tokens. Penalties
    are neutral (frequency/presence 0.0, repetition 1.0) and
    `no_penalties` is True; tests overwrite the fields they exercise.
    """
    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(0,
                              vocab_size,
                              size=np.random.randint(
                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        min_p=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                      vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        min_tokens={},
        logit_bias=[None] * batch_size,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )
    return fake_sampling_metadata


def _generate_min_token_penalties_and_stop_tokens(
    num_output_tokens: int, batch_size: int, vocab_size: int,
    batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
    """
    Generates and returns a dict of minimum token penalties and
    corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
    batch.

    If a batch index is included in `batch_indices_for_min_token_penalty`,
    a higher `min_tokens` value is assigned (within a randomized range),
    and a random set of stop token IDs is created. Otherwise, a lower
    `min_tokens` value is assigned, and the stop token IDs set is empty.
    """
    min_tokens: dict[int, tuple[int, set[int]]] = {}
    for index in range(batch_size):
        if index in batch_indices_for_min_token_penalty:
            # min_tokens is strictly greater than the number of output
            # tokens, so the stop tokens below must be masked.
            min_tokens[index] = (
                np.random.randint(num_output_tokens + 1,
                                  2 * num_output_tokens),
                {
                    np.random.randint(0, vocab_size - 1)
                    for _ in range(np.random.randint(0, vocab_size))
                })
        else:
            # min_tokens is strictly smaller than the number of output
            # tokens and there are no stop tokens: nothing to mask.
            min_tokens[index] = (np.random.randint(0,
                                                   num_output_tokens), set())
    return min_tokens


def _create_weighted_output_token_list(
        batch_size: int,
        vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
    """
    Creates an output token list where each token occurs a distinct
    number of times.

    For each batch, a random subset of token IDs is selected from the
    vocabulary. The selected tokens are then added to the output token
    list, each with a different frequency.

    Returns:
        tuple[list[list[int]], list[list[int]]]:
            - The first element is the output token list, where each
              sublist corresponds to a batch and contains tokens with
              weighted frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding
              output list (least frequent first).
    """
    output_token_ids: list[list[int]] = []
    sorted_token_ids_in_output: list[list[int]] = []
    for _ in range(batch_size):
        distinct_token_ids = np.random.choice(vocab_size,
                                              size=np.random.randint(1, 10),
                                              replace=False).tolist()
        sorted_token_ids_in_output.append(distinct_token_ids)
        output_token_ids_for_batch = []
        for index, token_id in enumerate(distinct_token_ids):
            # The token at position `index` is repeated `index + 1` times.
            output_token_ids_for_batch.extend([token_id] * (index + 1))
        output_token_ids.append(output_token_ids_for_batch)
    return output_token_ids, sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
    """
    Tests that if the number of output tokens is less than
    SamplingParams.min_tokens then we will set the logits for
    the stop token ids to -inf.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # `np.random.randint` excludes its upper bound, so use `batch_size`
    # (not `batch_size - 1`) to make every request eligible, and
    # `batch_size + 1` so that up to `batch_size` indices can be drawn.
    # Repeated indices are harmless: the list is only used for membership.
    batch_indices_for_min_token_penalty = np.random.randint(
        0, batch_size, size=np.random.randint(0, batch_size + 1)).tolist()
    min_tokens = _generate_min_token_penalties_and_stop_tokens(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
        batch_indices_for_min_token_penalty)
    sampling_metadata.min_tokens = min_tokens
    sampler = Sampler()
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        _, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
        for token_id in range(VOCAB_SIZE):
            if token_id in stop_token_ids:
                assert logits[batch_idx][token_id] == -float("inf")
            else:
                assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(device: str, batch_size: int,
                                  presence_penalty: float):
    """
    Test to verify that if presence penalty is enabled then tokens
    are penalized as per their presence in the existing output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    output_token_ids = sampling_metadata.output_token_ids
    sampling_metadata.presence_penalties = _create_penalty_tensor(
        batch_size, presence_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        # Since all tokens initially have the same logits, the non-penalized
        # token ID will be the one with the highest logit value, while the
        # penalized token ID will be the one with the lowest logit value.
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        if presence_penalty > 0:
            # If `presence_penalty` is set to a value greater than 0, it
            # indicates a preference for new tokens over those already
            # present in the output.
            # Verify that the penalized token ID exists in the output, while
            # the non-penalized token ID does not.
            assert penalized_token_id in output_token_ids[batch_idx]
            assert non_penalized_token_id not in output_token_ids[batch_idx]
        elif presence_penalty < 0:
            # If `presence_penalty` is set to a value less than 0, it
            # indicates a preference for existing tokens over new ones.
            # Verify that the non-penalized token ID exists in the output,
            # while the penalized token ID does not.
            assert non_penalized_token_id in output_token_ids[batch_idx]
            assert penalized_token_id not in output_token_ids[batch_idx]


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(device: str, batch_size: int,
                                   frequency_penalty: float):
    """
    Test to verify that if frequency penalty is enabled then tokens are
    penalized as per their frequency of occurrence.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.frequency_penalties = _create_penalty_tensor(
        batch_size, frequency_penalty, torch.device(device))
    output_token_ids, sorted_token_ids_in_output = \
        _create_weighted_output_token_list(
            batch_size,
            VOCAB_SIZE,
        )
    sampling_metadata.output_token_ids = output_token_ids
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
            batch_idx]
        # The list is ordered by ascending frequency, so the last entry is
        # the token that occurs most often in the output.
        most_frequent_token_id = distinct_sorted_token_ids_in_output[-1]
        if frequency_penalty > 0:
            # If `frequency_penalty` is set to > 0, it indicates
            # a preference for new tokens over existing ones. Verify that
            # the non-penalized token ID is not present in the output, while
            # the most penalized token is the one that occurs most
            # frequently in the output.
            assert (non_penalized_token_id
                    not in distinct_sorted_token_ids_in_output)
            assert penalized_token_id == most_frequent_token_id
        elif frequency_penalty < 0:
            # If `frequency_penalty` is set to < 0, it indicates
            # a preference for existing tokens over new ones. Verify that
            # the non-penalized token ID is the one that occurs most
            # frequently in the output, while the penalized token ID is one
            # that has not yet appeared.
            assert non_penalized_token_id == most_frequent_token_id
            assert (penalized_token_id
                    not in distinct_sorted_token_ids_in_output)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(device: str, batch_size: int,
                                    repetition_penalty: float):
    """
    Test to verify that when the repetition penalty is enabled, tokens
    are penalized based on their presence in the prompt or the existing
    output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.repetition_penalties = _create_penalty_tensor(
        batch_size, repetition_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        # The prompt row was built with pad=VOCAB_SIZE, so any padding it
        # contains lies outside the valid token-id range and cannot equal
        # an argmax/argmin index.
        prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx].tolist()
        output_tokens = sampling_metadata.output_token_ids[batch_idx]
        if repetition_penalty > 1.0:
            # If `repetition_penalty` > 1.0, verify that the non-penalized
            # token ID has not been seen before, while the penalized token ID
            # exists either in the prompt or the output.
            assert (non_penalized_token_id not in prompt_tokens
                    and non_penalized_token_id not in output_tokens)
            assert (penalized_token_id in prompt_tokens
                    or penalized_token_id in output_tokens)
        elif repetition_penalty < 1.0:
            # If `repetition_penalty` < 1.0, verify that the penalized
            # token ID has not been seen before, while the non-penalized
            # token ID exists either in the prompt or the output.
            assert (penalized_token_id not in prompt_tokens
                    and penalized_token_id not in output_tokens)
            assert (non_penalized_token_id in prompt_tokens
                    or non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
    """
    Tests that when min_p is applied, tokens with probability below
    min_p * max_prob are masked with -inf.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    # Create one dominant token per batch: token 0 gets a high logit, all
    # other tokens keep the low 1e-2 value they were created with.
    fake_logits[:, 0] = 10.0
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Configure min_p parameters
    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
    sampler = Sampler()
    logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        for token_id in range(VOCAB_SIZE):
            if token_id == 0:
                # Dominant token should always be unmasked
                assert logits[batch_idx][token_id] != -float("inf")
            elif min_p > 0.0:
                # Non-dominant tokens should be masked when min_p > 0
                assert logits[batch_idx][token_id] == -float("inf")
            else:
                # No masking when min_p is 0
                assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
    """
    Test to verify that a per-request logit bias is added to the logit of
    the biased token only, while every other logit is left unchanged.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.logit_bias = _create_logit_bias(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        bias_value=bias_value,
    )
    sampler = Sampler()
    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        # Mirrors `_create_logit_bias`: request i biases token min(i, V - 1).
        biased_index = min(batch_idx, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if biased_index == token_id:
                assert logits_for_req[token_id] == pytest.approx(bias_value +
                                                                 1e-2)
            else:
                assert logits_for_req[token_id] == pytest.approx(1e-2)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
                                   num_allowed_token_ids: int):
    """
    Test to verify that exactly the tokens flagged in
    `allowed_token_ids_mask` have their logits set to -inf, and that
    requests without any flagged token are left untouched.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    mask = _create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        num_allowed_token_ids=num_allowed_token_ids,
        device=device,
    )
    sampling_metadata.allowed_token_ids_mask = mask
    sampler = Sampler()
    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        if batch_idx % 2 == 1:
            # Odd-indexed requests have an all-False mask row.
            assert torch.all(logits_for_req != -float("inf"))
            continue
        # Mirrors the flagged range built by `_create_allowed_token_ids`.
        start = min(batch_idx, VOCAB_SIZE - 1)
        end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if start <= token_id < end:
                assert logits_for_req[token_id] == -float(
                    "inf"), f"{batch_idx}, {token_id}"
            else:
                assert logits_for_req[token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
                           bad_words_lengths: tuple[int, ...]):
    """
    Test to verify that when the bad words restriction is present, tokens
    are penalized based on their match with the bad words.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
        batch_size, VOCAB_SIZE, bad_words_lengths)
    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
        sampling_metadata, VOCAB_SIZE)
    sampler = Sampler()
    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        # Requests without bad words have no entry and expect no masking.
        banned_token_ids = bad_words_last_tokens.get(batch_idx, [])
        for token_id in range(VOCAB_SIZE):
            if token_id in banned_token_ids:
                assert logits_for_req[token_id] == -float("inf")
            else:
                assert logits_for_req[token_id] != -float("inf")