异步编程+令牌桶算法:批量调用 LLM API
背景
最近在工作中着手模型评测平台的搭建,其中有这么一个场景:需要调用其他部门提供的 LLM API 进行在评测集上跑模型评测,但这个 LLM API 有请求速率限制 - 最多 1 秒调用 2 次。所以我的任务概括来说就是:如何在严格遵守 API 速率请求的情况下,最大提高并发度加快模型评测速度。本文的内容主要记录了对这个任务的尝试,以及最后的解决方案
环境设置
为了方便做演示,我这里采用 fastapi
搭建了 /chat
端点,并且用 slowapi
给 /chat
加上 API 速率限制,代码如下
import random
import asyncio
from fastapi import FastAPI, Response, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Note: the route decorator must be above the limit decorator, not below it
@app.get("/chat")
@limiter.limit("2/second")
async def chat(request: Request, response: Response):
await asyncio.sleep(random.randint(2, 5))
return "hello world"
本地启动 FastAPI 服务之后,http://127.0.0.1:8000/chat
就可以用来模拟带有速率限制的 LLM API 了
首次尝试:多线程 + Requests
显然这是一个 I/O 密集型任务,因此我的第一直觉就是用 Python 的多线程,搭配 requests
库试一下,首先导入必要的库
import time
import requests
from concurrent.futures import ThreadPoolExecutor
然后编写核心函数 call_api
,它自带重试机制,并且会在发生错误的时候等待一段时间之后重试
def call_api(
url: str,
idx: int,
max_retry: int = 5,
retry_sleep: int = 3,
retry_multiplier: int = 2,
) -> dict:
for i in range(max_retry):
print(f"[Request {idx}] Sending {time.time() - START:.2f}")
try:
response = requests.get(url)
response.raise_for_status()
if response:
print(
f"[Request {idx}] SUCCESS {time.time() - START:.2f}: receive {response.json()}"
)
return response.json()
retry_sleep = 3
except Exception:
print(
f"[Request {idx}] FAIL {time.time() - START}: retry in {retry_sleep} seconds"
)
time.sleep(retry_sleep)
retry_sleep = retry_sleep**retry_multiplier
raise Exception("Hit max retry limit")
然后创建线程池,大小设置为 2(考虑到 API 速率限制是 1 秒 2 次),并假设需要请求 10 次
print(f"=== {time.time() - START:.2f} Start All Tasks ===")
with ThreadPoolExecutor(2) as executor:
urls = ["http://127.0.0.1:8000/chat"] * 10
indices = list(range(len(urls)))
results = executor.map(call_api, urls, indices)
print(f"=== {time.time() - START:.2f} Finish All Tasks ===")
下面是某一次的运行日志
=== 0.00 Start All Tasks ===
[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 1] SUCCESS 3.01: receive hello world
[Request 2] Sending 3.01
[Request 0] SUCCESS 5.01: receive hello world
[Request 3] Sending 5.01
[Request 3] SUCCESS 7.01: receive hello world
[Request 4] Sending 7.01
[Request 2] SUCCESS 7.01: receive hello world
[Request 5] Sending 7.02
[Request 5] SUCCESS 10.01: receive hello world
[Request 6] Sending 10.01
[Request 4] SUCCESS 12.01: receive hello world
[Request 7] Sending 12.01
[Request 6] SUCCESS 13.01: receive hello world
[Request 8] Sending 13.01
[Request 7] SUCCESS 15.02: receive hello world
[Request 9] Sending 15.02
[Request 8] SUCCESS 17.02: receive hello world
[Request 9] SUCCESS 20.02: receive hello world
=== 20.02 Finish All Tasks ===
启动之后,我发现完成 10 个请求的时间远远没有达到理论速率。10 个请求按照 2 秒钟 1 个来算,再考虑得到一次模型回复最多等待 5 秒,那么理论上 10 秒钟就应该得到答案了,怎么也不会到 20 秒
不信邪的我用 hyperfine 进行了多次测试,下面是 hyperfine
的输出
Benchmark 1: uv run multi_thread.py
Time (mean ± σ): 18.688 s ± 2.007 s [User: 0.118 s, System: 0.028 s]
Range (min … max): 15.198 s … 22.165 s 10 runs
可以看到,确实从来没有一次可以在 10 秒钟之内完成 10 个请求的处理,平均需要 18 秒左右,那么问题出在哪里了呢?从日志输出中,我发现了一丝端倪
[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 2] Sending 3.01
[Request 3] Sending 5.01
一开始发送 2 个请求没问题,但是后续 2 个请求并没有立刻发送,而是分别等待了 3 秒和 5 秒后才继续发送请求,而我预期的是每一秒都要有 2 个请求发送
🤔 如果你对异步编程有一定的了解的话,不难想明白其中的道理:requests 是同步的,call_api
函数里面发送请求后线程只能阻塞等待 API 返回结果
结论:不要用多线程 + requests
并发 HTTP 请求
第二次尝试:aiohttp 异步编程
想明白问题之后,我有了解决问题的思路:用异步框架发送 HTTP 请求。在一番搜索之后,我找到了 aiohttp
,官方文档 写得很简洁很好上手(如果你理解异步编程的话)
首先改造一下 call_api
函数为异步的,主要是替换 requests.get
方法
async def call_api(
session: aiohttp.ClientSession,
url: str,
idx: int,
max_retry: int = 5,
retry_sleep: int = 3,
retry_multiplier: int = 2,
) -> dict:
for i in range(max_retry):
print(f"[Request {idx}] Sending {time.time() - START:.2f}")
try:
async with session.get(url) as response:
response.raise_for_status()
result = await response.json()
print(
f"[Request {idx}] SUCCESS {time.time() - START:.2f}: receive {result}"
)
return result
retry_sleep = 3
except Exception:
print(
f"[Request {idx}] FAIL {time.time() - START}: retry in {retry_sleep} seconds"
)
await asyncio.sleep(retry_sleep)
retry_sleep = retry_sleep**retry_multiplier
raise Exception("Hit max retry limit")
然后写一个 main
函数并启动即可,这里用的是 asyncio.gather
方法
async def main():
async with aiohttp.ClientSession() as session:
urls = ["http://127.0.0.1:8000/chat"] * 10
results = asyncio.gather(
*[call_api(session, url, i) for i, url in enumerate(urls)]
)
await results
if __name__ == "__main__":
print(f"=== {time.time() - START:.2f} Start All Tasks ===")
asyncio.run(main())
print(f"=== {time.time() - START:.2f} Finish All Tasks ===")
启动之后你可以看到如下的日志
=== 0.00 Start All Tasks ===
[Request 0] Sending 0.00
[Request 1] Sending 0.00
[Request 2] Sending 0.00
[Request 3] Sending 0.00
[Request 4] Sending 0.00
[Request 5] Sending 0.00
[Request 6] Sending 0.00
[Request 7] Sending 0.00
[Request 8] Sending 0.00
[Request 9] Sending 0.00
[Request 4] FAIL 0.00676417350769043: retry in 3 seconds
[Request 9] FAIL 0.0068662166595458984: retry in 3 seconds
[Request 2] FAIL 0.006894111633300781: retry in 3 seconds
[Request 7] FAIL 0.006909370422363281: retry in 3 seconds
[Request 5] FAIL 0.006922006607055664: retry in 3 seconds
[Request 3] FAIL 0.0069332122802734375: retry in 3 seconds
[Request 8] FAIL 0.006942033767700195: retry in 3 seconds
[Request 6] FAIL 0.006951332092285156: retry in 3 seconds
...
🤔 启动了 10 个请求,有 8 个是失败的,显然这里失败是因为速率限制。但起码目前看来,发送 HTTP 请求的时候不阻塞了,只是需要找到一种方式限制 HTTP 请求的速率
结论:用异步框架(如 aiohttp
)并发你的 HTTP 请求
最终版本:aiohttp 异步编程 + 令牌桶算法
1 秒钟 2 个请求,那么令牌补充的速率就是 2,令牌桶容量也是 2
令牌桶,顾名思义它是一个包含令牌的桶。每当你要做些啥的时候,就需要从桶里取令牌。令牌会以某些规则生成,常见的比如每秒钟生成多少个令牌
观察上面的描述不难得出令牌桶算法的几个核心要素
rate
:令牌生成的速度,单位是个/秒tokens
:当前持有的令牌数量capaciry
: 令牌桶的大小last_update
:上一次补充令牌的时间need
:每次操作需要消耗的令牌数量,一般为 1
算法流程如下
- 初始化
rate = ...
,结合业务场景,设置为指定的速率。在本文的场景rate=2
capacity = ...
,设置为令牌桶大小。在本文的场景capacity=2
tokens = 0
,设置为 0,用于严格限制速率last_update = time.monotonic()
- 算法流程
- 根据当前的时间和上一次补充令牌的时间
last_update
的差值补充令牌tokens
,注意更新的时候不能超过令牌上限 - 检查当前令牌的数量
- 如果补充后的令牌足够使用(
tokens >= need
),那么就扣掉相应的令牌need
,然后执行操作 - 否则进行等待,显然需要等待的时间是
(need - tokens) / rate
。等待完成后又回到第 1 步
- 如果补充后的令牌足够使用(
- 根据当前的时间和上一次补充令牌的时间
- 变量
tokens
和变量last_update
是共享变量,可能被多个协程改动。所以需要加锁 - 用asyncio.Lock()
。 time.monotonic
和time.time
不一样,只有前者保证了单调性,适合用于计算时间差值,所以这里用这个
根据算法流程不难写出如下代码
class RateLimiter:
def __init__(self, rate: int, capacity: int):
self.rate = rate
self.tokens = 0
self.capacity = capacity
self.lock = asyncio.Lock()
self.last_update = time.monotonic()
async def acquire(self, need: int = 1):
while True:
async with self.lock:
now = time.monotonic()
elapsed = now - self.last_update
# Add new tokens
new_tokens = int(elapsed * self.rate)
if new_tokens > 0:
self.tokens = min(self.capacity, self.tokens + new_tokens)
self.last_update = now
if self.tokens >= need:
self.tokens -= need
return
deficit = need - self.tokens
wait_time = deficit / self.rate
await asyncio.sleep(wait_time)
再对前面的代码稍加修改
为了简洁,重复的地方用 ...
省略表示
async def call_api(
...
rate_limiter: RateLimiter,
...
) -> dict:
...
try:
await rate_limiter.acquire()
...
except Exception:
...
...
async def main():
async with aiohttp.ClientSession() as session:
urls = ["http://127.0.0.1:8000/chat"] * 10
rate_limiter = RateLimiter(rate=2, capacity=2)
results = asyncio.gather(
*[call_api(session, rate_limiter, url, i) for i, url in enumerate(urls)]
)
await results
if __name__ == "__main__":
print(f"=== {time.time() - START:.2f} Start All Tasks ===")
asyncio.run(main())
print(f"=== {time.time() - START:.2f} Finish All Tasks ===")
最后,用 hyperfine
测一下,结果如下所示
Benchmark 1: uv run async_concurrency.py
Time (mean ± σ): 9.766 s ± 0.552 s [User: 0.173 s, System: 0.032 s]
Range (min … max): 8.715 s … 10.240 s 10 runs
可以看到,改进很明显,基本就达到了理论速度
总结
上面基本还原了我在处理这个问题时候的解决思路。作为一名算法工程师,日常工作用 Python 多进程处理数据比较多,本次为了处理模型评测性能问题用到的异步编程、令牌桶算法对我来说都是比较新的东西,整体下来还是挺有趣的,也学到了新的知识点 :0