目录

异步编程+令牌桶算法:批量调用 LLM API

最近在工作中着手模型评测平台的搭建,其中有这么一个场景:需要调用其他部门提供的 LLM API 进行在评测集上跑模型评测,但这个 LLM API 有请求速率限制 - 最多 1 秒调用 2 次。所以我的任务概括来说就是:如何在严格遵守 API 速率请求的情况下,最大提高并发度加快模型评测速度。本文的内容主要记录了对这个任务的尝试,以及最后的解决方案

为了方便做演示,我这里采用 fastapi 搭建了 /chat 端点,并且用 slowapi/chat 加上 API 速率限制,代码如下

import random
import asyncio
from fastapi import FastAPI, Response, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

limiter = Limiter(key_func=get_remote_address)

app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


# Note: the route decorator must be above the limit decorator, not below it
@app.get("/chat")
@limiter.limit("2/second")
async def chat(request: Request, response: Response):
    await asyncio.sleep(random.randint(2, 5))
    return "hello world"

本地启动 FastAPI 服务之后,http://127.0.0.1:8000/chat 就可以用来模拟带有速率限制的 LLM API 了

显然这是一个 I/O 密集型任务,因此我的第一直觉就是用 Python 的多线程,搭配 requests 库试一下,首先导入必要的库

import time
import requests
from concurrent.futures import ThreadPoolExecutor

然后编写核心函数 call_api,它自带重试机制,并且会在发生错误的时候等待一段时间之后重试

def call_api(
    url: str,
    idx: int,
    max_retry: int = 5,
    retry_sleep: int = 3,
    retry_multiplier: int = 2,
) -> dict:
    for i in range(max_retry):
        print(f"[Request {idx}] Sending {time.time() - START:.2f}")
        try:
            response = requests.get(url)
            response.raise_for_status()
            if response:
                print(
                    f"[Request {idx}] SUCCESS {time.time() - START:.2f}: receive {response.json()}"
                )
                return response.json()
            retry_sleep = 3
        except Exception:
            print(
                f"[Request {idx}] FAIL {time.time() - START}: retry in {retry_sleep} seconds"
            )
            time.sleep(retry_sleep)
            retry_sleep = retry_sleep**retry_multiplier
    raise Exception("Hit max retry limit")

然后创建线程池,大小设置为 2(考虑到 API 速率限制是 1 秒 2 次),并假设需要请求 10 次

print(f"=== {time.time() - START:.2f} Start All Tasks ===")

with ThreadPoolExecutor(2) as executor:
    urls = ["http://127.0.0.1:8000/chat"] * 10
    indices = list(range(len(urls)))
    results = executor.map(call_api, urls, indices)

print(f"=== {time.time() - START:.2f} Finish All Tasks ===")

下面是某一次的运行日志

=== 0.00 Start All Tasks ===
[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 1] SUCCESS 3.01: receive hello world
[Request 2] Sending 3.01
[Request 0] SUCCESS 5.01: receive hello world
[Request 3] Sending 5.01
[Request 3] SUCCESS 7.01: receive hello world
[Request 4] Sending 7.01
[Request 2] SUCCESS 7.01: receive hello world
[Request 5] Sending 7.02
[Request 5] SUCCESS 10.01: receive hello world
[Request 6] Sending 10.01
[Request 4] SUCCESS 12.01: receive hello world
[Request 7] Sending 12.01
[Request 6] SUCCESS 13.01: receive hello world
[Request 8] Sending 13.01
[Request 7] SUCCESS 15.02: receive hello world
[Request 9] Sending 15.02
[Request 8] SUCCESS 17.02: receive hello world
[Request 9] SUCCESS 20.02: receive hello world
=== 20.02 Finish All Tasks ===

启动之后,我发现完成 10 个请求的时间远远没有达到理论速率。10 个请求按照 2 秒钟 1 个来算,再考虑得到一次模型回复最多等待 5 秒,那么理论上 10 秒钟就应该得到答案了,怎么也不会到 20 秒

不信邪的我用 hyperfine 进行了多次测试,下面是 hyperfine 的输出

Benchmark 1: uv run multi_thread.py
  Time (mean ± σ):     18.688 s ±  2.007 s    [User: 0.118 s, System: 0.028 s]
  Range (min  max):   15.198 s  22.165 s    10 runs

可以看到,确实从来没有一次可以在 10 秒钟之内完成 10 个请求的处理,平均需要 18 秒左右,那么问题出在哪里了呢?从日志输出中,我发现了一丝端倪

[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 2] Sending 3.01
[Request 3] Sending 5.01

一开始发送 2 个请求没问题,但是后续 2 个请求并没有立刻发送,而是分别等待了 3 秒和 5 秒后才继续发送请求,而我预期的是每一秒都要有 2 个请求发送

🤔 如果你对异步编程有一定的了解的话,不难想明白其中的道理:requests 是同步的,call_api 函数里面发送请求后线程只能阻塞等待 API 返回结果

Note

结论:不要用多线程 + requests 并发 HTTP 请求

想明白问题之后,我有了解决问题的思路:用异步框架发送 HTTP 请求。在一番搜索之后,我找到了 aiohttp官方文档 写得很简洁很好上手(如果你理解异步编程的话)

首先改造一下 call_api 函数为异步的,主要是替换 requests.get 方法

async def call_api(
    session: aiohttp.ClientSession,
    url: str,
    idx: int,
    max_retry: int = 5,
    retry_sleep: int = 3,
    retry_multiplier: int = 2,
) -> dict:
    for i in range(max_retry):
        print(f"[Request {idx}] Sending {time.time() - START:.2f}")
        try:
            async with session.get(url) as response:
                response.raise_for_status()
                result = await response.json()
                print(
                    f"[Request {idx}] SUCCESS {time.time() - START:.2f}: receive {result}"
                )
                return result
            retry_sleep = 3
        except Exception:
            print(
                f"[Request {idx}] FAIL {time.time() - START}: retry in {retry_sleep} seconds"
            )
            await asyncio.sleep(retry_sleep)
            retry_sleep = retry_sleep**retry_multiplier
    raise Exception("Hit max retry limit")

然后写一个 main 函数并启动即可,这里用的是 asyncio.gather 方法

async def main():
    async with aiohttp.ClientSession() as session:
        urls = ["http://127.0.0.1:8000/chat"] * 10
        results = asyncio.gather(
            *[call_api(session, url, i) for i, url in enumerate(urls)]
        )
        await results


if __name__ == "__main__":
    print(f"=== {time.time() - START:.2f} Start All Tasks ===")
    asyncio.run(main())
    print(f"=== {time.time() - START:.2f} Finish All Tasks ===")

启动之后你可以看到如下的日志

=== 0.00 Start All Tasks ===
[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 2] Sending 0.00
[Request 3] Sending 0.00
[Request 4] Sending 0.00
[Request 5] Sending 0.00
[Request 6] Sending 0.00
[Request 7] Sending 0.00
[Request 8] Sending 0.00
[Request 9] Sending 0.00
[Request 4] FAIL 0.00676417350769043: retry in 3 seconds
[Request 9] FAIL 0.0068662166595458984: retry in 3 seconds
[Request 2] FAIL 0.006894111633300781: retry in 3 seconds
[Request 7] FAIL 0.006909370422363281: retry in 3 seconds
[Request 5] FAIL 0.006922006607055664: retry in 3 seconds
[Request 3] FAIL 0.0069332122802734375: retry in 3 seconds
[Request 8] FAIL 0.006942033767700195: retry in 3 seconds
[Request 6] FAIL 0.006951332092285156: retry in 3 seconds
...

🤔 启动了 10 个请求,有 8 个是失败的,显然这里失败是因为速率限制。但起码目前看来,发送 HTTP 请求的时候不阻塞了,只是需要找到一种方式限制 HTTP 请求的速率

Note

结论:用异步框架(如 aiohttp)并发你的 HTTP 请求

Tip

1 秒钟 2 个请求,那么令牌补充的速率就是 2,令牌桶容量也是 2

令牌桶,顾名思义它是一个包含令牌的桶。每当你要做些啥的时候,就需要从桶里取令牌。令牌会以某些规则生成,常见的比如每秒钟生成多少个令牌

观察上面的描述不难得出令牌桶算法的几个核心要素

  • rate:令牌生成的速度,单位是个/秒
  • tokens:当前持有的令牌数量
  • capaciry: 令牌桶的大小
  • last_update:上一次补充令牌的时间
  • need:每次操作需要消耗的令牌数量,一般为 1

算法流程如下

  • 初始化
    • rate = ...,结合业务场景,设置为指定的速率。在本文的场景 rate=2
    • capacity = ...,设置为令牌桶大小。在本文的场景 capacity=2
    • tokens = 0,设置为 0,用于严格限制速率
    • last_update = time.monotonic()
  • 算法流程
    1. 根据当前的时间和上一次补充令牌的时间 last_update 的差值补充令牌 tokens,注意更新的时候不能超过令牌上限
    2. 检查当前令牌的数量
      • 如果补充后的令牌足够使用(tokens >= need),那么就扣掉相应的令牌 need,然后执行操作
      • 否则进行等待,显然需要等待的时间是 (need - tokens) / rate。等待完成后又回到第 1 步
Tip
  1. 变量 tokens 和变量 last_update 是共享变量,可能被多个协程改动。所以需要加锁 - 用 asyncio.Lock()
  2. time.monotonictime.time 不一样,只有前者保证了单调性,适合用于计算时间差值,所以这里用这个

根据算法流程不难写出如下代码

class RateLimiter:
    def __init__(self, rate: int, capacity: int):
        self.rate = rate
        self.tokens = 0
        self.capacity = capacity
        self.lock = asyncio.Lock()
        self.last_update = time.monotonic()

    async def acquire(self, need: int = 1):
        while True:
            async with self.lock:
                now = time.monotonic()
                elapsed = now - self.last_update

                # Add new tokens
                new_tokens = int(elapsed * self.rate)
                if new_tokens > 0:
                    self.tokens = min(self.capacity, self.tokens + new_tokens)
                    self.last_update = now

                if self.tokens >= need:
                    self.tokens -= need
                    return

                deficit = need - self.tokens
                wait_time = deficit / self.rate

            await asyncio.sleep(wait_time)

再对前面的代码稍加修改

Info

为了简洁,重复的地方用 ... 省略表示

async def call_api(
    ...
    rate_limiter: RateLimiter,
    ...
) -> dict:
    ...
        try:
            await rate_limiter.acquire()
            ...
        except Exception:
            ...
    ...

async def main():
    async with aiohttp.ClientSession() as session:
        urls = ["http://127.0.0.1:8000/chat"] * 10
        rate_limiter = RateLimiter(rate=2, capacity=2)
        results = asyncio.gather(
            *[call_api(session, rate_limiter, url, i) for i, url in enumerate(urls)]
        )
        await results

        
if __name__ == "__main__":
    print(f"=== {time.time() - START:.2f} Start All Tasks ===")
    asyncio.run(main())
    print(f"=== {time.time() - START:.2f} Finish All Tasks ===")

最后,用 hyperfine 测一下,结果如下所示

Benchmark 1: uv run async_concurrency.py
  Time (mean ± σ):      9.766 s ±  0.552 s    [User: 0.173 s, System: 0.032 s]
  Range (min  max):    8.715 s  10.240 s    10 runs

可以看到,改进很明显,基本就达到了理论速度

上面基本还原了我在处理这个问题时候的解决思路。作为一名算法工程师,日常工作用 Python 多进程处理数据比较多,本次为了处理模型评测性能问题用到的异步编程、令牌桶算法对我来说都是比较新的东西,整体下来还是挺有趣的,也学到了新的知识点 :0