본문 바로가기
AI/LLM

[LLM] LMCache 코드 분석 - store, retrieve

by 잔디🌿 2025. 12. 31.

    이제 논문을 봤으니 코드를 분석해보아야한다

     

    vLLM과 LMCache 연결

    def _init_lmcache_engine(
        lmcache_config: LMCacheEngineConfig,
        vllm_config: "VllmConfig",
        role: str,
    ) -> LMCacheEngine:
        """Initialize the LMCache engine by the given model config and parallel
        config. This function will check the environment variable
        `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
        variable is not set, this function will return None.
    
        :param lmcache_config: The LMCache configuration.
        :type lmcache_config: LMCacheEngineConfig
        :param vllm_config: The vLLM configuration.
        :type vllm_config: VllmConfig
    
        :return: The initialized LMCache engine
        :rtype: LMCacheEngine
        """
        if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME):
            return curr_engine
    
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
    
        assert isinstance(lmcache_config, LMCacheEngineConfig), (
            "LMCache v1 configuration is should be passed."
        )
    
        kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
    
        use_mla = mla_enabled(model_config)
        if use_mla and (
            lmcache_config.remote_serde != "naive"
            and lmcache_config.remote_serde is not None
        ):
            raise ValueError("MLA only works with naive serde mode..")
    
        # construct kv shape (for mem pool)
        num_layer = model_config.get_num_layers(parallel_config)
        num_draft_layers = _calculate_draft_layers(vllm_config, model_config)
        num_layer += num_draft_layers
        chunk_size = lmcache_config.chunk_size
        # this is per gpu
        num_kv_head = model_config.get_num_kv_heads(parallel_config)
        head_size = model_config.get_head_size()
        kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
        logger.info(
            f"num_layer: {num_layer}, chunk_size: {chunk_size}, "
            f"num_kv_head (per gpu): {num_kv_head}, head_size: {head_size}, "
            f"hidden_dim (D) for KV (per gpu): {num_kv_head * head_size}, "
            f"use mla: {use_mla}, kv shape: {kv_shape}, num_draft_layers:{num_draft_layers}"
        )
    
        # Change current device.
        if current_platform.is_cuda_alike():
            logger.info("CUDA device is available. Using CUDA for LMCache engine.")
            torch_dev = torch.cuda
            dev_name = "cuda"
        elif current_platform.is_xpu():
            logger.info("XPU device is available. Using XPU for LMCache engine.")
            torch_dev = torch.xpu
            dev_name = "xpu"
        else:
            raise RuntimeError("Unsupported device platform for LMCache engine.")
    
        num_gpus = torch_dev.device_count()
        local_rank = parallel_config.rank % num_gpus
        torch_dev.set_device(local_rank)
        device = torch.device(f"{dev_name}:{local_rank}")
        metadata = LMCacheEngineMetadata(
            model_config.model,
            parallel_config.world_size,
            parallel_config.rank,
            "vllm",
            kv_dtype,
            kv_shape,
            use_mla,
            role,
        )
    
        use_gpu = need_gpu_interm_buffer(lmcache_config)
        vllm_gpu_connector: Optional[GPUConnectorInterface]
    
        # Validate MLA with layerwise configurations
        if use_mla and lmcache_config.use_layerwise and lmcache_config.enable_blending:
            raise ValueError(
                "We haven't supported MLA with Cacheblend yet. Please disable blending."
            )
    
        # When use_mla is True, num_kv_head is 1
        hidden_dim_size = num_kv_head * head_size
        if role == "scheduler":
            vllm_gpu_connector = None
            # Create a dummy tpg object with broadcast and broadcast_object methods
            tpg = SimpleNamespace()
            tpg.broadcast = lambda tensor, src: tensor
            tpg.broadcast_object = lambda obj, src: obj
        elif lmcache_config.use_layerwise:
            if lmcache_config.enable_blending:
                # Use layerwise connector for blending
                vllm_gpu_connector = VLLMBufferLayerwiseGPUConnector(
                    hidden_dim_size,
                    num_layer,
                    use_gpu=use_gpu,
                    chunk_size=chunk_size,
                    dtype=kv_dtype,
                    device=device,
                )
            else:
                vllm_gpu_connector = VLLMPagedMemLayerwiseGPUConnector(
                    hidden_dim_size,
                    num_layer,
                    use_gpu=use_gpu,
                    chunk_size=chunk_size,
                    dtype=kv_dtype,
                    device=device,
                    use_mla=use_mla,
                )
            tpg = get_tp_group()
        else:
            if current_platform.is_cuda_alike():
                connector_cls = VLLMPagedMemGPUConnectorV2
            elif current_platform.is_xpu():
                connector_cls = VLLMPagedMemXPUConnectorV2
            else:
                raise RuntimeError("No supported connector found for the current platform.")
    
            vllm_gpu_connector = connector_cls(
                hidden_dim_size,
                num_layer,
                use_gpu=use_gpu,
                chunk_size=chunk_size,
                dtype=kv_dtype,
                device=device,
                use_mla=use_mla,
            )
            tpg = get_tp_group()
        engine = LMCacheEngineBuilder.get_or_create(
            ENGINE_NAME,
            lmcache_config,
            metadata,
            vllm_gpu_connector,
            tpg.broadcast,
            tpg.broadcast_object,
        )
        if role == "scheduler" and lmcache_config.enable_scheduler_bypass_lookup:
            assert engine.save_only_first_rank or lmcache_config.get_extra_config_value(
                "remote_enable_mla_worker_id_as0", metadata.use_mla
            ), (
                "enable_scheduler_bypass_lookup is only supported with "
                "save_only_first_rank or remote_enable_mla_worker_id_as0"
            )
        return engine

     

    해당 부분은 vllm과 lmcache를 연결하는 부분의 코드이다.

    vllm은 LLM의 동작을 위해서 동작하는 엔진이다. 주로 prefill, decode 등을 담당한다. 

    VLLM에서 kvcache는 하나의 쿼리에 대한 응답으로만 사용된다. 하지만 우리는 만들어진 kvcache을 앞으로도 계속 사용하고자 하는 것이므로, 이를 저장하고 관리하는 엔진이 필요한데 이것이 LMCache이다.

     

    LMCache는 VLLM 프로세스마다 생성된다. vllm의 경우, 스케줄러와 워커 각각 프로세스를 가지는 경우가 있는데, 위 코드를 보면 vllm의 role에 따라서 이에 맞는 LMCache 엔진을 생성한다. 

    뿐만 아니라,

    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    cache_config = vllm_config.cache_config

    위와 같이 다양한 vllm 설정에 따라 LMCache를 설정하고, 이미 엔진이 존재하면 새로 만들지 않고 해당 엔진을 반환하여 싱글톤을 유지하는 등의 기능 등을 한다.

     

    KVCache Store

    cache engine.py의 store()을 보겠다.

    def store(
            self,
            tokens: Optional[Union[torch.Tensor, list[int]]] = None,
            hashes: Optional[List[int]] = None,
            offsets: Optional[List[int]] = None,
            mask: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> None:

    일단 다음과 같은 파라미터를 받는다.

    vllm에서 kv캐시를 저장할때에는 token만 주고, 해시된 토큰의 값을 줄때는 해당 값과 몇개의 토큰 단위로 묶였는지를 알려주는 Offset를 전달한다. 또한 어디서부터 저장할지를 의미하는 mask도 전달한다.

    mask는 FFFTTTTT와 같이 생겨서 특정 지점 이상만 저장하라는 것을 표시한다.

    kwargs에는 gpu에 있는 kv캐시를 꺼내기 위해서 필요한 정보가 저장되어있다.

     

            assert self.gpu_connector is not None, (
                "gpu_connector is required for store operation"
            )

    워커 모드로 생성된 엔진인지 확인한다. 스케줄러인 경우, gpu connector이 생성되지 않기 때문이다.

     

    if self._is_passive():
        logger.debug(f"rank={self.metadata.worker_id} ignore store")
        return

    또한 현재 연결된 vllm의 rank를 나타낸다. 최적화 정책에 따라 특정 rank보다 작으면 저장하지 않는다.

     

    if mask is not None:
                num_to_store_tokens = torch.sum(mask).item()
            elif tokens is not None:
                num_to_store_tokens = len(tokens)
            elif hashes is not None:
                assert offsets is not None, (
                    "Offsets should be set when hashes are provided during store"
                )
                num_to_store_tokens = sum(offsets)
                kwargs["slot_mapping"] = torch.tensor(
                    kwargs["slot_mapping"], dtype=torch.long, device="cuda"
                )

    다음은 저장해야 할 토큰의 수를 구한다.

    mask가 있으면 mask 속 true 의 갯수를 세고, token을 통째로 넘겨줬을떄는 해당 배열의 길이를 센다.

    만약 해시 된 단위로 넘겨주었으면 offsets의 모든 수를 더하여 토큰의 길이를 측정한다. 또한 torch.tensor 함수를 통해서 kwargs["slot_mapping]을 tensor로 만들고, 자료형을 바꾼 후 cuda 디바이스에 올린다.

     

     monitor_req_id = self.stats_monitor.on_store_request(num_to_store_tokens)

    이후 저장 요청이 시작되었다고 통계 모듈에 알려준다.

     

            starts: List[int] = []
            ends: List[int] = []
            keys: List[CacheEngineKey] = []
            memory_objs: List[MemoryObj] = []
    
            offload_time = 0.0
            put_time = 0.0
            tot_kv_size = 0
            tot_token_num = 0
            t = time.perf_counter()

    그 다음 데이터를 저장할 자료구조를 만든다,

    starts,ends는 각 청크의 시작부분, 끝 부분의 인덱스를 나타내고,  각 청크의 Key는 keys에 담고, 그리고 해당 청크를 담을 cpu의 버퍼는 memory_objs에 넣는다.

     

    offload_time는 GPU → CPU 복사 시간
    put_time는 storage backend에 저장하는 시간
    tot_kv_size는 저장한 KV 총 바이트
    tot_token_num은 실제 저장한 토큰 수(= chunk들의 합)을 나타낸다.

     

            request_configs = kwargs.get("request_configs")
            if request_configs is not None and len(request_configs) != 0:
                assert isinstance(request_configs, dict)

    이후 설정값 묶음이 dict 형태인지 확인한 후

     

    for start, end, key in self.token_database.process_tokens(
                tokens,
                hashes,
                offsets,
                mask,
                request_configs=request_configs,
            ):
                assert isinstance(key, CacheEngineKey)
                # Allocate the memory object
                num_tokens = end - start
                kv_shape = self.gpu_connector.get_shape(num_tokens)
                kv_dtype = self.metadata.kv_dtype
    
                # TODO (Jiayi): should be batched in the future
                memory_obj = self.storage_manager.allocate(
                    kv_shape,
                    kv_dtype,
                    busy_loop=self.force_store_wait,
                    fmt=self.fmt,
                )
                if memory_obj is None:
                    logger.warning(
                        "Local cpu memory under pressure so"
                        " choosing to store only "
                        f" {len(memory_objs)}"
                        " total chunks of KV cache."
                    )
                    break

     

    그리고 for문을 통해서 청크 단위로 값을 저장한다.

    token_database.process_tokens 함수에 현재 store 함수에서 받은 값을 넣으면, 값을 청크단위로 쪼개 start, end, key 값을 리턴한다.

    end-start를 통해서 chunk의 길이를 계산한 후, gpu_connector.get_shape를 통해서 shape를 얻는다.

    여기서 shape는 kv 캐시 한 덩어리 모양을 뜻한다. 이 shape는 지금 layerwise인지? layer 수는 뭐인지 등등 모델의 구조에 따라 달라지기 때문에 해당 과정이 필요하다.

     

    이후 self.metadata.kv_dtype를 통해 kv 캐시의 dtype를 가져온다. 이 설정값은 metadata에 들어있다.

     

    memory_obj = self.storage_manager.allocate(
        kv_shape,
        kv_dtype,
        busy_loop=self.force_store_wait,
        fmt=self.fmt,
    )

    그 다음 이 부분에서 cpu 버퍼를 할당한다. (이때 데이터를 넣는 것이 아니라 공간만 확보하는 것)

    memory_obj is None이면 메모리 부족이므로 경고 후 중단한다.

     

                starts.append(start)
                ends.append(end)
                keys.append(key)
                memory_objs.append(memory_obj)
                tot_kv_size += memory_obj.get_size()
                tot_token_num += num_tokens

    이후 시작지점, 끝지점, key를 저장하고, 할당받은 메모리 덩어리도 저장한다. 

    또한 전체 kv 사이즈, 토큰 사이즈를 갱신한다.

     

    self.gpu_connector.batched_from_gpu(memory_objs, starts, ends, **kwargs)
    offload_time += time.perf_counter() - t

    이후 gpu에서 cpu로 kv캐시를 복사한다.

    starts, ends(청크의 시작과 끝), **kwargs를 통해서 gpu에 있는 해당 kv 캐시를 찾고, 미리 할당해둔 memory_objs에 넣는다.

    이후 오프로드 시간을 갱신한다. 

     

            transfer_spec = kwargs.get("transfer_spec", None)
            # TODO: we implicitly rely on batched_put to call ref_count_down
            # this management should be done in a cleaner way
            self.storage_manager.batched_put(keys, memory_objs, transfer_spec=transfer_spec)

    이후 transfer_spec에 어떻게, 어디에 보낼지 등등의 정보를 저장한 후, kv 캐시를 백엔드에 저장한다.

    keys는 kv캐시의 키이고, memory_objs는 gpu에서 복사된 kv캐시이다.

     

    put_time += time.perf_counter() - t

    이후 백엔드에 저장한 시간을 측정한다.

     

    self.stats_monitor.on_store_finished(monitor_req_id, tot_token_num)

    이제 저장이 끝났으므로, 통계 모듈에 보고한다.

     

    KVCache Retrieve

      def retrieve(
            self,
            tokens: Union[torch.Tensor, list[int]],
            mask: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> torch.Tensor:

    여기서 self는 현재 이를 동작시키는 엔진 정보를 담고 있고, 토큰와 마스크(토큰의 어떤 부분의 캐시를 가지려고 하는지), rmflrh kwargs는 vllm의 위치정보이다.

     

    tot_kv_size = 0
    t = time.perf_counter()

    이번 retirieve에서 가져온 총 바이트수를 저장, 완료되는데까지 걸리는 시간을 저장하기 위한 변수를 생성한다.

     

    if mask is not None:
        num_required_tokens = torch.sum(mask).item()
    else:
        num_required_tokens = len(tokens)
    monitor_req_id = self.stats_monitor.on_retrieve_request(num_required_tokens)

    이후, 요청되는 토큰의 길이를 계산한다. 만약 마스크가 존재하면, 마스크의 합계로 계산하고 없으면 토큰의 길이로 계산한다.

    이후 stats_monitor을 통해서 이번 요청에 대한 id값을 얻는다.

     

    ret_mask = torch.zeros(len(tokens), dtype=torch.bool, device="cpu")

    이건 kv캐시를 가져온 토큰을 표시하기 위한 마스크이다. 초기값은 모두 false이다.

     

            if not self._is_passive():
                if self.async_loading:
                    reordered_chunks, tot_kv_size = self._async_process_tokens_internal(  # noqa: E501
                        tokens,
                        mask,
                        ret_mask,
                        **kwargs,
                    )
                else:
                    reordered_chunks, tot_kv_size = self._process_tokens_internal(
                        tokens,
                        mask,
                        ret_mask,
                        **kwargs,
                    )
            if self.save_only_first_rank:
                with torch.cuda.stream(self.broadcast_stream):
                    self._broadcast_or_receive_memory_objs(
                        reordered_chunks,
                        ret_mask,
                    )
    
                # if self.gpu_connector has load_stream, self.broadcast_stream is equals
                # to self.gpu_connector.load_stream, the broadcast and to_gpu operation
                # will execute sequentially within the stream.
                # if self.gpu_connector does not have load_stream, self.broadcast_stream
                # is created by torch.cuda.Stream(), we need to synchronize broadcast
                # operation, and then process to_cpu operation.
                if not hasattr(self.gpu_connector, "load_stream"):
                    self.broadcast_stream.synchronize()

    이제 kv 캐시를 gpu에서 찾아본다.

    lmcache에서는 분산된 환경으로 인해 gpu가 여러개일 수 있다. 이때 모든 gpu가 kv캐시를 찾는 과정을 수행한다면 이는 비효율적일 수 있다.

     

    만약 active이면 현재 gpu가 직접 Lookup(gpu에 캐시가 있는지) 하고, 나머지 gpu는 passive로, broadcast를 통해서 다른 gpu에 요청해서 가져온다.

    이 때 동기, 비동기에 따라 코드가 갈린다.

    만약 동기라면, _process_tokens_internal을 통해서 캐시를 순차적으로 가져온다

    하지만 비동기라면 async_process_tokens_internal을 통해서 캐시를 비동기적으로 가져온다.

     

    그 다음 이 작업에 수행되는 cuda작업을 broadcast_stream에 올린 후, broadcast_or_receive_memory_objs를 통해서 first rank(active)인 경우에는 현재 가져온 캐시를 다른 작업에 전달하고, 나머지(passive)는 first rank가 보낸 캐시를 받는다.

     

    이 때 gpu load stream과 broadcast_stream이 독립적이라면 위의 broadcast 과정이 다 끝나지 않아 kv캐시가 준비되지 않았음에도 불구하고, gpu load가 호출될 수 있다. 이를 방지하기 위해 synchronize()를 해준다.

     

    if len(reordered_chunks) > 0:
                _, memory_objs, starts, ends = zip(*reordered_chunks, strict=False)
                self.gpu_connector.batched_to_gpu(
                    list(memory_objs), list(starts), list(ends), **kwargs
                )

    만약 hit 해서 가져온 청크가 있다면, 이를 gpu에 올린다. 이때 batched_to_gpu를 통해서 한번에 처리한다.

     

            for key, memory_obj, _, _ in reordered_chunks:
                if self.remove_after_retrieve and not self._is_passive():
                    assert self.storage_manager is not None
                    self.storage_manager.remove(key)
                memory_obj.ref_count_down()
    
            onload_time = time.perf_counter() - t
    
            retrieved_tokens = torch.sum(ret_mask)
            self.stats_monitor.on_retrieve_finished(monitor_req_id, retrieved_tokens)

    그리고, 가져온 캐시를 삭제하는 과정을 거친다.

    self.remove_after_retrieve는 한번 retrieve 되었으면 삭제하라는 뜻이다. 주로 일회용 캐시에 쓰인다. 이 때 self.is_passive()를 통해서 active인 상태에서만 삭제가 가능하도록 한다.

    이후, 조회 시간을 잰 뒤, 모니터에 종료되었음을 알린다.

     

    return ret_mask

    이후 mask를 리턴하여 어떤 토큰에 대한 캐시가 로드되었는지 표시한다.