# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Compare the with and without prefix caching.""" import copy from typing import Optional import pytest import torch from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock, hash_block_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) def make_request(request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, prompt_logprobs: Optional[int] = None, cache_salt: Optional[str] = None): if mm_positions is None: multi_modal_inputs = None else: multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) return Request( request_id=request_id, prompt_token_ids=prompt_token_ids, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, lora_request=None, cache_salt=cache_salt, ) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False), ) ], ) def make_kv_cache_config_hybrid_model(block_size: int, num_blocks: int) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], FullAttentionSpec(block_size, 1, 1, torch.float32, False), ), KVCacheGroupSpec( ["layer2"], SlidingWindowSpec(block_size, 1, 1, torch.float32, False, sliding_window=2 * block_size), ), KVCacheGroupSpec( ["layer3"], SlidingWindowSpec(block_size, 1, 1, torch.float32, False, sliding_window=2 * block_size), ), ], ) @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, caching_hash_algo=hash_algo, ) # choose the hash function according to the parameter hash_fn = sha256 if hash_algo == "sha256" else hash # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[ block_id].block_hash.block_hash == block_hash assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) # All blocks should be available. assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be # [unallocated (6, 7, 8, 9, 10)] # [unique_req0 (4)] # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([6], ) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert manager.block_pool.free_block_queue.num_free_blocks == 6 assert all([ b.ref_cnt == 0 for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) assert len([ b for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) == 6 manager.free(req2) # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None def test_prefill_hybrid_model(): block_size = 16 manager = KVCacheManager( make_kv_cache_config_hybrid_model(block_size, 21), max_model_len=8192, enable_caching=True, ) hash_fn = hash # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(block_size)] # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]) # Check full block metadata parent_block_hash = None for length, block_ids in zip((1, 2, 3), ((1, 5, 9), (2, 6, 10), (3, 7, 11))): block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) for block_id in block_ids: assert manager.block_pool.blocks[ block_id].block_hash.block_hash == block_hash assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, 8, 12): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: if block != manager.block_pool.null_block: assert block.ref_cnt == 2 block_hashes = manager.req_to_block_hashes[req1.request_id] manager.free(req0) manager.free(req1) cached_block_hash_to_block_bak = copy.copy( manager.block_pool.cached_block_hash_to_block) def test_partial_request_hit(request_id: str, hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(manager.req_to_block_hashes[req.request_id]) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block[ hash_with_group_id] = cached_block_hash_to_block_bak[ hash_with_group_id] manager.free(req) # Evict the blocks outside sliding window, does not affect the hit length. test_partial_request_hit("2", [ BlockHashWithGroupId(block_hashes[0], 1), BlockHashWithGroupId(block_hashes[0], 2) ], 3) # Evict the first block of full attention, makes total cache miss. test_partial_request_hit("3", [ BlockHashWithGroupId(block_hashes[0], 0), ], 0) # Evict the last block of all layers, reduces the hit length to 2. test_partial_request_hit("4", [ BlockHashWithGroupId(block_hashes[2], 0), BlockHashWithGroupId(block_hashes[2], 1), BlockHashWithGroupId(block_hashes[2], 2), ], 2) # Evict the last block of full attention, reduces the hit length to 2. test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], 2) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], 2) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], 2) # Evict different set of blocks for full attention and sliding window makes # total cache miss. # The cache hit length of full attention is 1 * block_size. # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers have different hit length. test_partial_request_hit("8", [ BlockHashWithGroupId(block_hashes[2], 0), BlockHashWithGroupId(block_hashes[0], 1), BlockHashWithGroupId(block_hashes[0], 2), ], 0) def test_prefill_plp(): '''Test prefill with APC and some prompt logprobs (plp) requests. 1. Schedule plp request and validate APC block allocation 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks ''' manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) # the default hash function is hash hash_fn = hash # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] # Request #0 is a prompt logprobs request # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[ block_id].block_hash.block_hash == block_hash assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Request #1 is a non-prompt-logprobs request: # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) # All blocks should be available. assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be # [unallocated (6, 7, 8, 9, 10)] # [unique_req0 (4)] # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes assert block_ids != ([1, 2, 3, 4], ) # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. for block_id in block_ids[0]: assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) def test_decode(): manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None # Append slots with allocating a new block. req0.num_computed_tokens = 59 # 9 tokens to fill the previous block, and 10 tokens to fill # the preallocated block. for _ in range(9 + 10): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 19, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-2].block_hash is not None assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None def test_evict(): manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 assert manager.block_pool.free_block_queue.num_free_blocks == 1 manager.free(req0) manager.free(req1) assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([10], ) assert manager.block_pool.free_block_queue.num_free_blocks == 7 def test_hash_block_correct_reuse(): """ This tests when a previously cached block is reused as a new block, its hash metadata should be correctly reset. """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, ) # Allocate 1 block and cache it. num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 # Deallocate the block. manager.free(req) # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 assert manager.block_pool.blocks[blocks.blocks[0] [0].block_id].block_hash is None def test_computed_blocks_not_evicted(): """ Test that the computed blocks are not evicted when getting new blocks for a request if there are any other free blocks. """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, ) # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 # Free the blocks. manager.free(req0) manager.free(req1) # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 def test_basic_prefix_caching_disabled(): """ This tests that the prefix caching is disabled. """ block_size = 4 manager = KVCacheManager( make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, ) req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert not blocks @pytest.mark.parametrize("hash_fn", [sha256, hash]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks function of KVCacheManager. """ block_size = 4 block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, ) # Req: # Block 0: [0, 1, 2, 3] # Block 1: [4, 5, 6, 7] # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] req = make_request("0", list(range(14))) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) # Test that blocks that don't start from the beginning are cached correctly. blocks += [KVCacheBlock(block_id=2)] block_pool.cache_full_blocks( request=req, blocks=blocks, block_hashes=block_hashes, num_cached_blocks=2, num_full_blocks=3, block_size=block_size, hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None def test_cache_blocks_multi_group(): """ This tests that blocks are cached correctly for different kv cache groups. """ block_size = 4 block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) # Req: # Block 0/4: [0, 1, 2, 3] # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] req = make_request("0", list(range(14))) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, hash_fn=hash, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 assert len(block_hashes) == 2 assert all([block.block_hash is not None for block in blocks]) # Cache the blocks for group 1. blocks = [KVCacheBlock(block_id=i) for i in range(3)] block_pool.cache_full_blocks( request=req, blocks=blocks, block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=3, block_size=block_size, hash_fn=hash, kv_cache_group_id=1, ) assert len(block_pool.cached_block_hash_to_block) == 5 assert len(block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Block hash 0: hit for group 0 and 1 # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 assert block_pool.get_cached_block(block_hashes[0], kv_cache_group_ids=[0]) is not None assert block_pool.get_cached_block(block_hashes[1], kv_cache_group_ids=[0]) is not None assert block_pool.get_cached_block(block_hashes[2], kv_cache_group_ids=[0]) is None assert block_pool.get_cached_block(block_hashes[0], kv_cache_group_ids=[1]) is not None assert block_pool.get_cached_block(block_hashes[1], kv_cache_group_ids=[1]) is not None assert block_pool.get_cached_block(block_hashes[2], kv_cache_group_ids=[1]) is not None assert block_pool.get_cached_block(block_hashes[0], kv_cache_group_ids=[0, 1]) is not None assert block_pool.get_cached_block(block_hashes[1], kv_cache_group_ids=[0, 1]) is not None assert block_pool.get_cached_block(block_hashes[2], kv_cache_group_ids=[0, 1]) is None def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) # Common prompt tokens (T is text tokens and P is image placeholder tokens) # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1] common_token_ids = list(range(10)) + [-1] * 6 common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2 common_token_ids += [-1] * 16 common_mm_positions = [ PlaceholderRange(offset=11, length=10), PlaceholderRange(offset=30, length=18), ] common_mm_hashes = ["aaa", "bbb"] # A unique image plus some text tokens. unique_token_ids = [-1] * 7 + [100] * 4 all_token_ids = common_token_ids + unique_token_ids mm_positions = common_mm_positions + [ PlaceholderRange(offset=48, length=7) ] mm_hashes = common_mm_hashes + ["ccc"] req0 = make_request("0", all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 assert block_hashes[3].extra_keys == ("ccc", ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 all_token_ids = common_token_ids + unique_token_ids mm_positions = common_mm_positions + [ PlaceholderRange(offset=48, length=7) ] mm_hashes = common_mm_hashes + ["ccc"] req1 = make_request("1", all_token_ids, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 def test_cache_key_salting(): """ This tests that cache salts are applied during hashing and the cache is separated cache as expected. """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) # 3 complete blocks and an incomplete block with 11 tokens. common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 req0 = make_request("0", token_ids, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None assert block_hashes[2].extra_keys is None blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # Now one more block that should not have extra keys. assert len(block_hashes) == 4 assert block_hashes[3].extra_keys is None # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 req1 = make_request("1", token_ids, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * block_size # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req2.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ This is a unit test that tests the correctness of the allocate_slots when there is not enough free blocks. Specifically, when a request has computed blocks but cannot be allocated due to not enough free blocks, the computed blocks should not be touched. """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks) block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks) block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) assert {block.ref_cnt for block in block_part1[:3]} == {1} assert {block.ref_cnt for block in block_part1[3:]} == {0} # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, len(computed_blocks.blocks[0]) * 16, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. assert manager.allocate_slots(req3, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks) is None # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. assert {block.ref_cnt for block in block_part1[3:]} == {0} def test_reset_prefix_cache(): manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == ([5], ) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() assert manager.block_pool.cached_block_hash_to_block # Free the blocks. manager.free(req0) manager.free(req1) assert manager.reset_prefix_cache() assert not manager.block_pool.cached_block_hash_to_block assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, log_stats=False, # Disable logging stats ) assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None assert manager.prefix_cache_stats is None @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) def test_kv_cache_events(blocks_to_cache: int): block_size = 16 num_blocks = blocks_to_cache + 1 # Allocate Blocks # Should see a single block stored event with a blocks_to_cache number of # block hashes # take_events should reset the kv_event_queue manager = KVCacheManager( make_kv_cache_config(block_size, num_blocks), max_model_len=8192, enable_caching=True, enable_kv_cache_events=True, ) num_tokens = block_size * blocks_to_cache req0 = make_request("0", list(range(num_tokens))) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() block = events[-1] assert (len(block.block_hashes) == blocks_to_cache == len( manager.block_pool.cached_block_hash_to_block)) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 stored_block_hash = block.block_hashes # Remove blocks and send another request # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) req1 = make_request("1", list(range(num_tokens))) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() for blocks in events[:-1]: assert blocks.block_hashes[0] in stored_block_hash assert len(events) == blocks_to_cache + 1 assert (isinstance(events[-2], BlockRemoved)) assert (len(events[-1].block_hashes) == blocks_to_cache == len( manager.block_pool.cached_block_hash_to_block)) # All Blocks Cleared # Should see a single all blocks cleared event manager.free(req1) manager.reset_prefix_cache() events = manager.take_events() assert isinstance(events[-1], AllBlocksCleared) assert len(manager.block_pool.cached_block_hash_to_block) == 0 def test_eagle_enabled_removes_last_block(): """Verify Eagle does NOT remove blocks when request length is divisible by block size.""" block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, num_blocks=10), max_model_len=8192, enable_caching=True, use_eagle=True, ) # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) req = make_request("divisible_request", token_ids) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled req_eagle = make_request("eagle_divisible", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks # 2. drop last matched block → 1 remaining block assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # 16 tokens def test_eagle_with_partial_blocks(): """Test Eagle behavior with requests containing partial blocks.""" block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, num_blocks=10), max_model_len=8192, enable_caching=True, use_eagle=True, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) req = make_request("partial_block_test", token_ids) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks) manager.free(req) # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size def test_eagle_with_sliding_window(): """Test Eagle behavior with sliding window.""" block_size = 16 sliding_window_spec = SlidingWindowSpec( block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, sliding_window=block_size, use_mla=False, ) manager = KVCacheManager( KVCacheConfig( num_blocks=10, kv_cache_tensors=[], kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], ), max_model_len=8192, enable_caching=True, use_eagle=True, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) req = make_request("partial_block_test", token_ids) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # Evict the first block in the request assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None manager.block_pool.cached_block_hash_to_block.pop( BlockHashWithGroupId(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0