mirror of https://github.com/vllm-project/vllm.git
114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Any, Optional, Union
|
|
|
|
import msgspec
|
|
import zmq
|
|
from msgspec.msgpack import Decoder
|
|
|
|
|
|
#
|
|
# Types copied from vllm.distributed.kv_events
|
|
#
|
|
class EventBatch(
    msgspec.Struct,
    array_like=True,
    omit_defaults=True,
    gc=False,
):
    """Timestamped container holding a list of events.

    The struct is declared ``array_like``, so msgspec encodes it as a
    positional array; the order of the fields below is therefore part of
    the encoded format and must not be rearranged.
    """

    # Batch timestamp (units are not established in this file).
    ts: float
    # Events carried by the batch; left untyped in this base class.
    events: list[Any]
|
|
|
|
|
|
class KVCacheEvent(
    msgspec.Struct,
    array_like=True,
    omit_defaults=True,
    gc=False,
    tag=True,
):
    """Common base for every KV cache-related event.

    ``tag=True`` asks msgspec to tag each subclass, which is what lets a
    union of the concrete event types be decoded into the right class.
    """
|
|
|
|
|
|
class BlockStored(KVCacheEvent):
    """KV cache event describing stored blocks.

    Fields are encoded positionally (``array_like`` comes from the base
    class), so their order must match the producer's definition.
    """

    # Hashes of the blocks covered by this event.
    block_hashes: list[int]
    # Hash of the parent block; None is allowed (presumably "no parent").
    parent_block_hash: Optional[int]
    # Token ids associated with the stored blocks (layout not shown here).
    token_ids: list[int]
    # Block size; presumably tokens per block -- confirm with the producer.
    block_size: int
    # LoRA adapter id; None is allowed (presumably "no adapter in use").
    lora_id: Optional[int]
|
|
|
|
|
|
class BlockRemoved(KVCacheEvent):
    """KV cache event describing removed blocks."""

    # Hashes of the blocks covered by this event.
    block_hashes: list[int]
|
|
|
|
|
|
class AllBlocksCleared(KVCacheEvent):
    """KV cache event that carries no fields; the type itself is the message."""
|
|
|
|
|
|
class KVEventBatch(EventBatch):
    """Event batch whose entries are restricted to the KV cache event types.

    Narrowing ``events`` to a union of the tagged event structs lets a
    typed msgspec decoder build the concrete event classes.
    """

    events: list[
        Union[
            BlockStored,
            BlockRemoved,
            AllBlocksCleared,
        ]
    ]
|
|
|
|
|
|
def process_event(event_batch):
    """Print a batch header followed by one line per contained event.

    Args:
        event_batch: Object exposing ``ts`` and an iterable ``events``.
    """
    header = f"Received event batch at {event_batch.ts}:"
    print(header)
    for item in event_batch.events:
        line = f" - {item}"
        print(line)
|
|
|
|
|
|
def main(
    sub_endpoint: str = "tcp://localhost:5557",
    replay_endpoint: str = "tcp://localhost:5558",
    topic: str = "kv-events",
) -> None:
    """Subscribe to KV cache event batches and print them as they arrive.

    Each live message is expected to have three frames: topic, an 8-byte
    big-endian sequence number, and a msgpack payload. When a gap in the
    sequence numbers is detected, the missing batches are requested from
    the replay endpoint and processed before the live batch.

    The loop runs until interrupted with Ctrl-C. Errors raised while
    handling one message are printed and the loop continues.

    Args:
        sub_endpoint: ZMQ endpoint of the event publisher.
        replay_endpoint: ZMQ endpoint that serves replay requests.
        topic: Subscription topic prefix.
    """
    decoder = Decoder(type=KVEventBatch)
    # Sequence number of the newest batch seen so far; -1 means "none yet".
    last_seq = -1

    context = zmq.Context()
    try:
        # Set up the main subscription socket
        sub = context.socket(zmq.SUB)
        sub.connect(sub_endpoint)
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Initialize replay socket
        # NOTE(review): a REQ socket normally accepts a single reply per
        # request; confirm the replay endpoint's multi-message reply works
        # with this socket type.
        replay = context.socket(zmq.REQ)
        replay.connect(replay_endpoint)
        poller = zmq.Poller()
        poller.register(replay, zmq.POLLIN)

        print("Listening for KV cache events on topic:", topic)

        while True:
            try:
                if sub.poll(50):
                    _, seq_bytes, payload = sub.recv_multipart()
                    seq = int.from_bytes(seq_bytes, "big")

                    if last_seq >= 0 and seq > last_seq + 1:
                        missed = seq - last_seq - 1
                        print(
                            f"Missed {missed} messages (last: {last_seq}, current: {seq})"
                        )

                        replay.send((last_seq + 1).to_bytes(8, "big"))

                        # Give up on the replay if nothing arrives in 200 ms.
                        while poller.poll(timeout=200):
                            seq_bytes, replay_payload = replay.recv_multipart()
                            if not replay_payload:
                                # End of replay marker is sent as an empty
                                # frame for the payload
                                break

                            replay_seq = int.from_bytes(seq_bytes, "big")

                            if replay_seq > last_seq:
                                # Advance before decoding so a malformed
                                # payload is reported once and skipped
                                # rather than re-requested on every message.
                                last_seq = replay_seq
                                process_event(decoder.decode(replay_payload))
                                if replay_seq >= seq - 1:
                                    break

                    # Track the live sequence number. Without this,
                    # last_seq never leaves -1 and gaps are never detected.
                    last_seq = seq
                    process_event(decoder.decode(payload))

                # ... do other periodic work or check for shutdown ...

            except KeyboardInterrupt:
                print("Interrupted")
                break
            except Exception as e:
                # Best-effort loop: report the failure and keep listening.
                print("Error decoding message:", e)
    finally:
        # Close every socket created from this context and terminate it;
        # linger=0 drops unsent messages so shutdown cannot block.
        context.destroy(linger=0)
|
|
|
|
|
|
# Start the subscriber only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|