vllm.v1.kv_offload.spec

CanonicalKVCacheRef dataclass

A per-layer (or per-layer-group) reference that identifies a CanonicalKVCacheTensor by index and records the un-padded page size used by that layer.

Source code in vllm/v1/kv_offload/spec.py
@dataclass
class CanonicalKVCacheRef:
    """
    Per-layer (or per-layer-group) reference that identifies a
    CanonicalKVCacheTensor by index and records the un-padded page size
    used by that layer.
    """

    # Index into the list of CanonicalKVCacheTensor objects
    tensor_idx: int
    # The un-padded page size per block in bytes
    page_size_bytes: int

CanonicalKVCacheTensor dataclass

A canonicalized KV cache tensor whose first dimension is num_blocks.

For attention backends where the raw tensor has num_blocks at a non-leading physical dimension (e.g. FlashAttention's (2, num_blocks, ...) layout), the tensor is split so that each resulting CanonicalKVCacheTensor starts with (num_blocks, ...).

Source code in vllm/v1/kv_offload/spec.py
@dataclass
class CanonicalKVCacheTensor:
    """
    A canonicalized KV cache tensor whose first dimension is num_blocks.

    For attention backends where the raw tensor has num_blocks at a
    non-leading physical dimension (e.g. FlashAttention's
    (2, num_blocks, ...) layout), the tensor is split so that each
    resulting CanonicalKVCacheTensor starts with (num_blocks, ...).
    """

    # The KV cache tensor with shape (num_blocks, ...)
    tensor: torch.Tensor
    # The (possibly padded) page size per block in bytes
    page_size_bytes: int
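
The split described above can be illustrated without torch. This is a minimal sketch, assuming a raw cache laid out as (2, num_blocks, page): nested Python lists stand in for the tensor, and `split_leading` is a hypothetical helper, not a vLLM function. With a real `torch.Tensor`, `tensor.unbind(0)` produces analogous per-slice views; whether vLLM performs the split this way is not stated in this page.

```python
# Pure-Python stand-in for a raw KV cache with a (2, num_blocks, page) layout.
num_blocks, page_elems = 4, 3
raw = [
    [[(kv, block, i) for i in range(page_elems)] for block in range(num_blocks)]
    for kv in range(2)
]


def split_leading(raw_cache):
    """Split along the leading dimension so each piece is indexed by block first."""
    return list(raw_cache)


# Two pieces, each with num_blocks as its first dimension.
canonical = split_leading(raw)
```

Each element of `canonical` plays the role of one CanonicalKVCacheTensor: indexing it with a block number returns that block's page directly.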

CanonicalKVCaches dataclass

Canonicalized block-level representation of the KV caches.

Composed of:
  • A unique list of KV cache data tensors, each with shape (num_blocks, page_size_in_bytes) and int8 dtype.
  • Per-group data references into those tensors, that is, how each KV cache group maps to the tensors.
Source code in vllm/v1/kv_offload/spec.py
@dataclass
class CanonicalKVCaches:
    """
    Canonicalized block-level representation of the KV caches.

    Composed of:
        - Unique list of KV cache data tensors,
          each with shape (num_blocks, page_size_in_bytes) and int8 dtype.
        - Per-group data references of the tensors.
          i.e. how each KV cache group maps to the tensors.
    """

    # Ordered list of unique block tensors, each with shape
    # (num_blocks, ...).
    tensors: list[CanonicalKVCacheTensor]
    # Per-KV-cache-group list of data references that map each layer
    # in the group to the appropriate entry in the tensors list.
    group_data_refs: list[list[CanonicalKVCacheRef]]
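
To make the indexing relationship concrete, here is a small sketch of how `group_data_refs` entries could be resolved against the `tensors` list. `TensorStandIn` and `RefStandIn` are hypothetical stand-ins that mirror the field names above (a string replaces the `torch.Tensor`), and the page sizes and the sharing of one tensor between two groups are invented for illustration.

```python
from dataclasses import dataclass


@dataclass
class TensorStandIn:
    # Hypothetical stand-in for CanonicalKVCacheTensor; `name` replaces the tensor.
    name: str
    page_size_bytes: int  # possibly padded


@dataclass
class RefStandIn:
    # Hypothetical stand-in for CanonicalKVCacheRef.
    tensor_idx: int
    page_size_bytes: int  # un-padded


tensors = [TensorStandIn("t0", 4096), TensorStandIn("t1", 4096)]
group_data_refs = [
    [RefStandIn(0, 4096), RefStandIn(1, 4096)],  # group 0: two layers
    [RefStandIn(0, 2048)],  # group 1: assumed to reuse t0 with a smaller page
]


def resolve(group_idx):
    """Return (tensor name, un-padded page size) for every ref in a group."""
    return [
        (tensors[ref.tensor_idx].name, ref.page_size_bytes)
        for ref in group_data_refs[group_idx]
    ]


resolved_group0 = resolve(0)
resolved_group1 = resolve(1)
```

Keeping the tensors in one deduplicated list and referring to them by index lets several layers or groups point at the same storage without copying it.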

OffloadingSpec

Bases: ABC

Spec for an offloading connector

Source code in vllm/v1/kv_offload/spec.py
class OffloadingSpec(ABC):
    """Spec for an offloading connector"""

    def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"):
        logger.warning(
            "Initializing OffloadingSpec. This API is experimental and "
            "subject to change in the future as we iterate the design."
        )
        self.vllm_config = vllm_config
        self.kv_cache_config = kv_cache_config

        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None
        self.extra_config = kv_transfer_config.kv_connector_extra_config

        # block size used by vLLM for hashing request tokens for the sake
        # of enabling prefix caching
        self.hash_block_size = vllm_config.cache_config.block_size
        # gpu block size per group
        self.gpu_block_size: tuple[int, ...] = tuple(
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        )

        for block_size in self.gpu_block_size:
            assert block_size % self.hash_block_size == 0

        # offloaded_block_size / gpu_block_size
        self.block_size_factor: int = 1

        offloaded_block_size = self.extra_config.get("block_size")
        if offloaded_block_size is not None:
            offloaded_block_size_int = int(offloaded_block_size)
            gpu_block_sizes = set(self.gpu_block_size)
            assert len(gpu_block_sizes) == 1, (
                "If 'block_size' is specified in kv_connector_extra_config, "
                "there must be at least one KV cache group, "
                "and all groups must have the same block size."
            )
            gpu_block_size = gpu_block_sizes.pop()

            assert offloaded_block_size_int % gpu_block_size == 0
            self.block_size_factor = offloaded_block_size_int // gpu_block_size

    @abstractmethod
    def get_manager(self) -> OffloadingManager:
        """
        Get an OffloadingManager that will be used
        by the scheduler-side offloading connector to track
        offloaded blocks and manage evictions.
        """
        pass

    @abstractmethod
    def get_handlers(
        self, kv_caches: CanonicalKVCaches
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """
        Get offloading handlers along with their respective src and dst types.

        Args:
            kv_caches: Canonicalized KV caches.

        Yields:
            Tuples of (src_type, dst_type, offloading_handler).
        """
        pass
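
The block-size checks in `__init__` can be followed with a standalone sketch. The function below mirrors the arithmetic shown in the constructor; `compute_block_size_factor` is a hypothetical helper written for illustration and the sample values are assumptions, not vLLM defaults.

```python
def compute_block_size_factor(hash_block_size, gpu_block_sizes, offloaded_block_size=None):
    """Mirror the constructor's checks and return offloaded_block_size // gpu_block_size."""
    # Every per-group GPU block size must be a multiple of the hash block size.
    for block_size in gpu_block_sizes:
        assert block_size % hash_block_size == 0

    # Without an explicit "block_size" in the extra config, the factor stays 1.
    if offloaded_block_size is None:
        return 1

    # An explicit offloaded block size requires one shared GPU block size.
    unique_sizes = set(gpu_block_sizes)
    assert len(unique_sizes) == 1
    gpu_block_size = unique_sizes.pop()

    offloaded = int(offloaded_block_size)
    assert offloaded % gpu_block_size == 0
    return offloaded // gpu_block_size


# Two groups with 16-token GPU blocks, offloaded in 64-token blocks.
factor = compute_block_size_factor(16, (16, 16), "64")
```

In this example `factor` is 4, meaning each offloaded block covers four GPU blocks. Mixed GPU block sizes are accepted only when no offloaded block size is configured.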

get_handlers abstractmethod

Get offloading handlers along with their respective src and dst types.

Parameters:

Name       Type               Description               Default
kv_caches  CanonicalKVCaches  Canonicalized KV caches.  required

Yields:

Type                                                                Description
tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]  Tuples of (src_type, dst_type, offloading_handler).

Source code in vllm/v1/kv_offload/spec.py
@abstractmethod
def get_handlers(
    self, kv_caches: CanonicalKVCaches
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
    """
    Get offloading handlers along with their respective src and dst types.

    Args:
        kv_caches: Canonicalized KV caches.

    Yields:
        Tuples of (src_type, dst_type, offloading_handler).
    """
    pass
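
As an illustration of the yield contract, a subclass might implement `get_handlers` as a generator that emits one tuple per transfer direction. The sketch below does not import vLLM: every class in it (`SpecStandIn`, `GpuSpecStandIn`, `CpuSpecStandIn`, `HandlerStandIn`, `TwoWaySpec`) is a hypothetical stand-in, and the two-direction setup is an assumption about a typical use, not something this page specifies.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterator


# Hypothetical stand-ins; none of these are vLLM classes.
class LoadStoreSpecStandIn:
    pass


class GpuSpecStandIn(LoadStoreSpecStandIn):
    pass


class CpuSpecStandIn(LoadStoreSpecStandIn):
    pass


class HandlerStandIn:
    def __init__(self, direction: str):
        self.direction = direction


class SpecStandIn(ABC):
    @abstractmethod
    def get_handlers(self, kv_caches) -> Iterator[tuple[type, type, HandlerStandIn]]:
        """Yield (src_type, dst_type, handler) tuples."""


class TwoWaySpec(SpecStandIn):
    def get_handlers(self, kv_caches):
        # One handler per direction, keyed by the (src, dst) medium types.
        yield GpuSpecStandIn, CpuSpecStandIn, HandlerStandIn("store")
        yield CpuSpecStandIn, GpuSpecStandIn, HandlerStandIn("load")


# A consumer can build a lookup table from the yielded tuples.
registry = {
    (src, dst): handler
    for src, dst, handler in TwoWaySpec().get_handlers(kv_caches=None)
}
```

Yielding the source and destination types alongside each handler lets the caller route a transfer by its (src, dst) pair without inspecting the handler itself.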

get_manager abstractmethod

get_manager() -> OffloadingManager

Get an OffloadingManager that will be used by the scheduler-side offloading connector to track offloaded blocks and manage evictions.

Source code in vllm/v1/kv_offload/spec.py
@abstractmethod
def get_manager(self) -> OffloadingManager:
    """
    Get an OffloadingManager that will be used
    by the scheduler-side offloading connector to track
    offloaded blocks and manage evictions.
    """
    pass