diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index ab29f20..dd06c18 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -5,6 +5,7 @@ import asyncio import itertools import logging import time +from collections import deque from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from threading import Lock @@ -12,6 +13,7 @@ from typing import Literal, Callable, Any from urllib.parse import urlparse import httpx +import tiktoken from docutranslate.agents.thinking.thinking_factory import get_thinking_mode from docutranslate.logger import global_logger @@ -32,10 +34,10 @@ class AgentResultError(ValueError): class PartialAgentResultError(ValueError): """一个特殊的异常,用于表示结果不完整但包含了部分成功的数据,以便触发重试。该错误不计入总错误数""" - def __init__(self, message, partial_result: dict,append_prompt:str=None): + def __init__(self, message, partial_result: dict, append_prompt: str = None): super().__init__(message) self.partial_result = partial_result - self.append_prompt=append_prompt + self.append_prompt = append_prompt @dataclass(kw_only=True) @@ -46,11 +48,13 @@ class AgentConfig: model_id: str temperature: float = 0.7 concurrent: int = 30 - timeout: int = 1200 # 单位(秒),这个值是httpx.TimeOut中read的值,并非总的超时时间 + timeout: int = 1200 thinking: ThinkingMode = "disable" retry: int = 2 system_proxy_enable: bool = False - force_json: bool = False # 应输出json格式时强制ai输出json + force_json: bool = False + rpm: int | None = None # 每分钟请求数限制 + tpm: int | None = None # 每分钟Token数限制 class TotalErrorCounter: @@ -71,7 +75,6 @@ class TotalErrorCounter: return self.count > self.max_errors_count -# 仅使用多线程时用以计数 class PromptsCounter: def __init__(self, total: int, logger: logging.Logger): self.lock = Lock() @@ -85,21 +88,101 @@ class PromptsCounter: self.logger.info(f"多线程-已完成:{self.count}/{self.total}") +# --- 新增 RateLimiter 类 --- +class RateLimiter: + """ + 基于滑动窗口的速率限制器,支持 RPM 和 TPM 控制。 + 同时支持 Async 和 Sync 调用。 + """ + + def __init__(self, rpm: int | None, tpm: int | None): + self.rpm = rpm + self.tpm = tpm + # 双端队列存储 (timestamp, value),value对于RPM是1,对于TPM是token数量 + self.request_timestamps = deque() + self.token_timestamps = deque() + self.lock = Lock() # 用于同步模式和保护共享数据 + + def _cleanup_window(self, now: float): + """清理60秒窗口之前的数据""" + window_start = now - 60.0 + + while self.request_timestamps and self.request_timestamps[0] <= window_start: + self.request_timestamps.popleft() + + while self.token_timestamps and self.token_timestamps[0][0] <= window_start: + self.token_timestamps.popleft() + + def _check_and_get_wait_time(self, tokens: int) -> float: + """检查是否满足限制,返回需要等待的秒数。如果不需等待返回 0""" + now = time.time() + self._cleanup_window(now) + + wait_time = 0.0 + + # Check RPM + if self.rpm and len(self.request_timestamps) >= self.rpm: + # 取最早的一条记录,计算还需要等待多久才能腾出位置 + earliest = self.request_timestamps[0] + wait_time = max(wait_time, 60 - (now - earliest)) + + # Check TPM + if self.tpm: + current_tokens = sum(t[1] for t in self.token_timestamps) + if current_tokens + tokens > self.tpm: + # 稍微复杂点:需要移除足够多的旧token才能放入新token + # 这里做一个简化估算:如果超限,等到最早的记录过期 + if self.token_timestamps: + earliest = self.token_timestamps[0][0] + wait_time = max(wait_time, 60 - (now - earliest)) + else: + # 这种情况理论上不应该发生,除非单次请求超过了TPM上限 + # 如果单次超过上限,强制等待1秒(防止死循环)并允许通过(或者抛异常,这里选择允许) + pass + + return wait_time + + def _record_usage(self, tokens: int): + """记录使用量""" + now = time.time() + if self.rpm is not None: + self.request_timestamps.append(now) + if self.tpm is not None: + self.token_timestamps.append((now, tokens)) + + async def acquire_async(self, tokens: int = 0): + """异步等待配额""" + if self.rpm is None and self.tpm is None: + return + + while True: + with self.lock: + wait_time = self._check_and_get_wait_time(tokens) + if wait_time <= 0: + self._record_usage(tokens) + return + + # 释放锁后等待,避免阻塞其他协程/线程的检查 + # 添加一点点缓冲时间,避免刚唤醒时毫秒级误差导致再次等待 + await asyncio.sleep(wait_time + 0.1) + + def acquire_sync(self, tokens: int = 0): + """同步等待配额(线程阻塞)""" + if self.rpm is None and self.tpm is None: + return + + while True: + with self.lock: + wait_time = self._check_and_get_wait_time(tokens) + if wait_time <= 0: + self._record_usage(tokens) + return + + time.sleep(wait_time + 0.1) + + def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: - """ - 从API响应中提取token信息 - - 支持多种response格式: - 1. 格式1: usage.input_tokens_details.cached_tokens 和 usage.output_tokens_details.reasoning_tokens - 2. 格式2: usage.prompt_tokens_details.cached_tokens - 3. 格式3: usage.prompt_cache_hit_tokens 和 usage.completion_tokens_details.reasoning_tokens - - Args: - response_data: API响应数据 - - Returns: - tuple: (input_tokens, cached_tokens, output_tokens, reasoning_tokens) - """ + """(保持原样) 从API响应中提取token信息""" if "usage" not in response_data: return 0, 0, 0, 0 @@ -107,43 +190,34 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: input_tokens = usage.get("prompt_tokens", 0) output_tokens = usage.get("completion_tokens", 0) - # 初始化token详细统计 cached_tokens = 0 reasoning_tokens = 0 try: - # 尝试从不同格式获取cached_tokens - # 格式1: input_tokens_details.cached_tokens if ( - "input_tokens_details" in usage - and "cached_tokens" in usage["input_tokens_details"] + "input_tokens_details" in usage + and "cached_tokens" in usage["input_tokens_details"] ): cached_tokens = usage["input_tokens_details"]["cached_tokens"] - # 格式2: prompt_tokens_details.cached_tokens elif ( - "prompt_tokens_details" in usage - and "cached_tokens" in usage["prompt_tokens_details"] + "prompt_tokens_details" in usage + and "cached_tokens" in usage["prompt_tokens_details"] ): cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] - # 格式3: prompt_cache_hit_tokens (直接在usage下) elif "prompt_cache_hit_tokens" in usage: cached_tokens = usage["prompt_cache_hit_tokens"] - # 尝试从不同格式获取reasoning_tokens - # 格式1: output_tokens_details.reasoning_tokens if ( - "output_tokens_details" in usage - and "reasoning_tokens" in usage["output_tokens_details"] + "output_tokens_details" in usage + and "reasoning_tokens" in usage["output_tokens_details"] ): reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] - # 格式2: completion_tokens_details.reasoning_tokens elif ( - "completion_tokens_details" in usage - and "reasoning_tokens" in usage["completion_tokens_details"] + "completion_tokens_details" in usage + and "reasoning_tokens" in usage["completion_tokens_details"] ): reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] return input_tokens, cached_tokens, output_tokens, reasoning_tokens - except TypeError as e: - # print(f"获取token失败,跳过token计数:{e.__repr__()}") + except TypeError: return -1, -1, -1, -1 @@ -158,11 +232,11 @@ class TokenCounter: self.logger = logger def add( - self, - input_tokens: int, - cached_tokens: int, - output_tokens: int, - reasoning_tokens: int, + self, + input_tokens: int, + cached_tokens: int, + output_tokens: int, + reasoning_tokens: int, ): with self.lock: self.input_tokens += input_tokens @@ -170,10 +244,6 @@ class TokenCounter: self.output_tokens += output_tokens self.reasoning_tokens += reasoning_tokens self.total_tokens += input_tokens + output_tokens - # self.logger.debug( - # f"Token使用统计 - 输入: {self.input_tokens}(含cached: {self.cached_tokens}), " - # f"输出: {self.output_tokens}(含reasoning: {self.reasoning_tokens}), 总计: {self.total_tokens}" - # ) def get_stats(self): with self.lock: @@ -202,7 +272,6 @@ ErrorResultHandlerType = Callable[[str, logging.Logger], Any] class Agent: def __init__(self, config: AgentConfig): - self.baseurl = config.base_url.strip() if self.baseurl.endswith("/"): self.baseurl = self.baseurl[:-1] @@ -216,18 +285,38 @@ class Agent: self.thinking = config.thinking self.logger = config.logger self.total_error_counter = TotalErrorCounter(logger=self.logger) - # 新增:用于统计最终未解决的错误 self.unresolved_error_lock = Lock() self.unresolved_error_count = 0 - # 新增:用于统计token使用情况 self.token_counter = TokenCounter(logger=self.logger) - self.retry = config.retry - self.system_proxy_enable = config.system_proxy_enable + # 新增:初始化速率限制器 + self.rate_limiter = RateLimiter(rpm=config.rpm, tpm=config.tpm) + # 新增:初始化 encoding 用于估算 + self.encoding = self._get_encoding_for_model(self.model_id) + + def _get_encoding_for_model(self, model_name: str): + """获取 tiktoken encoding,如果失败则使用 cl100k_base 兜底""" + try: + return tiktoken.encoding_for_model(model_name) + except KeyError: + # 对于未知模型或自定义模型ID,使用 GPT-4 的默认编码器 + return tiktoken.get_encoding("cl100k_base") + + def _estimate_tokens(self, text: str) -> int: + """估算文本的 Token 数量""" + if not text: + return 0 + try: + # 这是一个近似值,不包含特殊 token 格式的开销,但用于限流足够了 + return len(self.encoding.encode(text)) + except Exception: + # 极端兜底:每4个字符算1个token + return len(text) // 4 + def _add_thinking_mode(self, data: dict): - thinking_mode_result=get_thinking_mode(self.domain,data.get("model")) + thinking_mode_result = get_thinking_mode(self.domain, data.get("model")) if thinking_mode_result is None: return field_thinking, val_enable, val_disable = thinking_mode_result @@ -236,9 +325,8 @@ class Agent: elif self.thinking == "disable": data[field_thinking] = val_disable - def _prepare_request_data( - self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False + self, prompt: str, system_prompt: str, temperature=None, top_p=0.9, json_format=False ): if temperature is None: temperature = self.temperature @@ -262,27 +350,32 @@ class Agent: return headers, data async def send_async( - self, - client: httpx.AsyncClient, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - force_json=False, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, - best_partial_result: dict | None = None, + self, + client: httpx.AsyncClient, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + force_json=False, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt if pre_send_handler: system_prompt, prompt = pre_send_handler(system_prompt, prompt) - # print(f"system_prompt:\n{system_prompt}") - # print(f"【测试】prompt:\n{prompt}") + + # 新增:速率限制检查 + # 计算估算的 tokens (system + user) + estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(prompt) + # 等待配额 + await self.rate_limiter.acquire_async(tokens=estimated_tokens) + headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json) should_retry = False - is_hard_error = False # 新增标志,用于区分是否为硬错误 + is_hard_error = False current_partial_result = None input_tokens = 0 output_tokens = 0 @@ -294,18 +387,14 @@ class Agent: headers=headers, timeout=self.timeout, ) - # print(f"【测试】json:\n{data}") response.raise_for_status() - # print(f"【测试】resp:\n{response.json()}") result = response.json()["choices"][0]["message"]["content"] - # print(f"【测试】\nprompt:\n{prompt}\nresp:\n{result}") - # 获取token使用情况 + response_data = response.json() input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( extract_token_info(response_data) ) - # 更新token计数器 self.token_counter.add( input_tokens, cached_tokens, output_tokens, reasoning_tokens ) @@ -313,7 +402,6 @@ class Agent: if retry_count > 0: self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。") - # print(f"result:=============================================================\n{result}\n================\n") return ( result if result_handler is None @@ -323,23 +411,23 @@ class Agent: except AgentResultError as e: self.logger.error(f"AI返回结果有误: {e}") should_retry = True - # 专门捕获部分翻译错误(软错误) except PartialAgentResultError as e: - # print(f"【测试】\nprompt:\n{prompt}\nresp:\n{result}") self.logger.error(f"收到部分返回结果,将尝试重试: {e}") current_partial_result = e.partial_result should_retry = True if e.append_prompt: - prompt+=e.append_prompt - # is_hard_error 保持 False + prompt += e.append_prompt - # 捕获硬错误 except httpx.HTTPStatusError as e: self.logger.error( f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}" ) should_retry = True is_hard_error = True + # 如果是因为 Rate Limit (429) 错误,最好在这里多睡一会儿,虽然我们有了本地 Limiter + if e.response.status_code == 429: + await asyncio.sleep(5) + except httpx.RequestError as e: self.logger.error(f"AI请求连接错误 (async): {repr(e)}") should_retry = True @@ -353,12 +441,10 @@ class Agent: best_partial_result = current_partial_result if should_retry and retry and retry_count < self.retry: - # 仅在硬错误时才增加总错误计数 if is_hard_error: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") - # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 return ( @@ -372,7 +458,6 @@ class Agent: ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") - # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 return ( @@ -386,7 +471,8 @@ class Agent: ) self.logger.info(f"正在重试第 {retry_count + 1}/{self.retry} 次...") - await asyncio.sleep(0.5) + # 指数退避 + await asyncio.sleep(0.5 * (2 ** retry_count)) return await self.send_async( client, prompt, @@ -402,7 +488,6 @@ class Agent: else: if should_retry: self.logger.error(f"所有重试均失败,已达到重试次数上限。") - # 新增:当所有重试失败后,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 @@ -417,30 +502,32 @@ class Agent: ) async def send_prompts_async( - self, - prompts: list[str], - system_prompt: str | None = None, - max_concurrent: int | None = None, - force_json=False, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + max_concurrent: int | None = None, + force_json=False, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: max_concurrent = ( self.max_concurrent if max_concurrent is None else max_concurrent ) total = len(prompts) + rpm_info = f", RPM:{self.rate_limiter.rpm}" if self.rate_limiter.rpm else "" + tpm_info = f", TPM:{self.rate_limiter.tpm}" if self.rate_limiter.tpm else "" + self.logger.info( - f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{force_json}" + f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent}{rpm_info}{tpm_info},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{force_json}" ) - self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") + self.logger.info(f"预计发送{total}个请求") + self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) - # 新增:在每次批量发送前重置计数器 self.unresolved_error_count = 0 - # 重置token计数器 self.token_counter.reset() count = 0 @@ -450,16 +537,19 @@ class Agent: proxies = get_httpx_proxies(asyn=True) if self.system_proxy_enable else None limits = httpx.Limits( - max_connections=self.max_concurrent * 2, # 为重试和并发预留空间 - max_keepalive_connections=self.max_concurrent, # 保持活动的连接数 + max_connections=self.max_concurrent * 2, + max_keepalive_connections=self.max_concurrent, ) async with httpx.AsyncClient( - trust_env=False, mounts=proxies, verify=False, limits=limits + trust_env=False, mounts=proxies, verify=False, limits=limits ) as client: async def send_with_semaphore(p_text: str): async with semaphore: + # 注意:我们在 semaphore 内部调用 send_async + # send_async 内部会调用 rate_limiter.acquire_async + # 这样可以防止并发过高,同时 rate_limiter 防止频率过快 result = await self.send_async( client=client, prompt=p_text, @@ -480,45 +570,44 @@ class Agent: results = await asyncio.gather(*tasks, return_exceptions=False) - # 新增:在所有任务完成后打印未解决的错误总数 self.logger.info( f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}" ) - # 新增:打印token使用统计 token_stats = self.token_counter.get_stats() - if token_stats["input_tokens"] < 0: - self.logger.info("Token统计失败") - else: - self.logger.info( - f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " - f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " - f"总计: {token_stats['total_tokens'] / 1000:.2f}K" - ) + self.logger.info( + f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " + f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " + f"总计: {token_stats['total_tokens'] / 1000:.2f}K" + ) return results def send( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - force_json=False, - pre_send_handler=None, - result_handler=None, - error_result_handler=None, - best_partial_result: dict | None = None, + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + force_json=False, + pre_send_handler=None, + result_handler=None, + error_result_handler=None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt if pre_send_handler: system_prompt, prompt = pre_send_handler(system_prompt, prompt) + # 新增:同步环境下的速率限制 + estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(prompt) + self.rate_limiter.acquire_sync(tokens=estimated_tokens) + headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json) should_retry = False - is_hard_error = False # 新增标志,用于区分是否为硬错误 + is_hard_error = False current_partial_result = None input_tokens = 0 output_tokens = 0 @@ -534,13 +623,11 @@ class Agent: result = response.json()["choices"][0]["message"]["content"] - # 获取token使用情况 response_data = response.json() input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( extract_token_info(response_data) ) - # 更新token计数器 self.token_counter.add( input_tokens, cached_tokens, output_tokens, reasoning_tokens ) @@ -556,20 +643,20 @@ class Agent: except AgentResultError as e: self.logger.error(f"AI返回结果有误: {e}") should_retry = True - # 专门捕获部分翻译错误(软错误) except PartialAgentResultError as e: self.logger.error(f"收到部分翻译结果,将尝试重试: {e}") current_partial_result = e.partial_result should_retry = True - # is_hard_error 保持 False - # 捕获硬错误 except httpx.HTTPStatusError as e: self.logger.error( f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}" ) should_retry = True is_hard_error = True + if e.response.status_code == 429: + time.sleep(5) + except httpx.RequestError as e: self.logger.error(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}") should_retry = True @@ -583,12 +670,10 @@ class Agent: best_partial_result = current_partial_result if should_retry and retry and retry_count < self.retry: - # 仅在硬错误时才增加总错误计数 if is_hard_error: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") - # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 return ( @@ -602,7 +687,6 @@ class Agent: ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") - # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 return ( @@ -616,7 +700,7 @@ class Agent: ) self.logger.info(f"正在重试第 {retry_count + 1}/{self.retry} 次...") - time.sleep(0.5) + time.sleep(0.5 * (2 ** retry_count)) return self.send( client, prompt, @@ -632,7 +716,6 @@ class Agent: else: if should_retry: self.logger.error(f"所有重试均失败,已达到重试次数上限。") - # 新增:当所有重试失败后,增加未解决错误计数 with self.unresolved_error_lock: self.unresolved_error_count += 1 @@ -647,16 +730,17 @@ class Agent: ) def _send_prompt_count( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str, - force_json, - count: PromptsCounter, - pre_send_handler, - result_handler, - error_result_handler + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str, + force_json, + count: PromptsCounter, + pre_send_handler, + result_handler, + error_result_handler ) -> Any: + # 该方法在 ThreadPoolExecutor 中运行 result = self.send( client, prompt, @@ -670,27 +754,28 @@ class Agent: return result def send_prompts( - self, - prompts: list[str], - system_prompt: str | None = None, - json_format=False, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + json_format=False, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: + rpm_info = f", RPM:{self.rate_limiter.rpm}" if self.rate_limiter.rpm else "" + tpm_info = f", TPM:{self.rate_limiter.tpm}" if self.rate_limiter.tpm else "" + self.logger.info( - f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{json_format}" + f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent}{rpm_info}{tpm_info},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{json_format}" ) self.logger.info( - f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" + f"预计发送{len(prompts)}个请求" ) self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) - # 新增:在每次批量发送前重置计数器 self.unresolved_error_count = 0 - # 重置token计数器 self.token_counter.reset() counter = PromptsCounter(len(prompts), self.logger) @@ -702,12 +787,13 @@ class Agent: result_handlers = itertools.repeat(result_handler, len(prompts)) error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) limits = httpx.Limits( - max_connections=self.max_concurrent * 2, # 允许连接复用 - max_keepalive_connections=self.max_concurrent, # 保持活跃连接 + max_connections=self.max_concurrent * 2, + max_keepalive_connections=self.max_concurrent, ) proxies = get_httpx_proxies(asyn=False) if self.system_proxy_enable else None + with httpx.Client( - trust_env=False, mounts=proxies, verify=False, limits=limits + trust_env=False, mounts=proxies, verify=False, limits=limits ) as client: clients = itertools.repeat(client, len(prompts)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: @@ -724,24 +810,15 @@ class Agent: ) output_list = list(results_iterator) - # 新增:在所有任务完成后打印未解决的错误总数 self.logger.info( f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}" ) - # 新增:打印token使用统计 token_stats = self.token_counter.get_stats() - if token_stats["input_tokens"] < 0: - self.logger.info("Token统计失败") - else: - self.logger.info( - f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " - f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " - f"总计: {token_stats['total_tokens'] / 1000:.2f}K" - ) + self.logger.info( + f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " + f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " + f"总计: {token_stats['total_tokens'] / 1000:.2f}K" + ) - return output_list - - -if __name__ == "__main__": - pass + return output_list \ No newline at end of file diff --git a/docutranslate/app.py b/docutranslate/app.py index 06c1fcb..5647d96 100644 --- a/docutranslate/app.py +++ b/docutranslate/app.py @@ -337,6 +337,12 @@ class GlossaryAgentConfigPayload(BaseModel): force_json: bool = Field( default=False, description="强制Agent输出JSON格式的术语表。" ) + rpm: Optional[int] = Field( + default=None, description="RPM限制 (Requests Per Minute)" + ) + tpm: Optional[int] = Field( + default=None, description="TPM限制 (Tokens Per Minute)" + ) # 1. 定义所有工作流共享的基础参数 @@ -411,6 +417,12 @@ class BaseWorkflowParams(BaseModel): force_json: bool = Field( default=False, description="应输出json格式时强制ai输出json" ) + rpm: Optional[int] = Field( + default=None, description="RPM限制 (Requests Per Minute)" + ) + tpm: Optional[int] = Field( + default=None, description="TPM限制 (Tokens Per Minute)" + ) @model_validator(mode="before") @classmethod @@ -723,6 +735,8 @@ class TranslateServiceRequest(BaseModel): "mineru_token": "your-mineru-token-if-any", "formula_ocr": True, "model_version": "vlm", + "rpm": 100, + "tpm": 100000, }, }, { @@ -1007,6 +1021,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1072,6 +1088,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1109,6 +1127,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1148,6 +1168,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1186,6 +1208,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1224,6 +1248,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1262,6 +1288,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1301,6 +1329,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1338,6 +1368,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -1378,6 +1410,8 @@ async def _perform_translation( "retry", "system_proxy_enable", "force_json", + "rpm", + "tpm", }, exclude_none=True, ) @@ -2507,6 +2541,8 @@ async def temp_translate( custom_prompt: Optional[str] = Body(None), model_version: Literal["pipeline", "vlm"] = Body("vlm"), glossary_dict: Optional[Dict[str, str]] = Body(None), + rpm: Optional[int] = Body(None), + tpm: Optional[int] = Body(None), ): file_name = Path(file_name) try: @@ -2530,6 +2566,8 @@ async def temp_translate( chunk_size=chunk_size, concurrent=concurrent, glossary_dict=glossary_dict, + rpm=rpm, + tpm=tpm, ), html_exporter_config=MD2HTMLExporterConfig(), ) diff --git a/docutranslate/static/index.html b/docutranslate/static/index.html index 46b7ddb..d8e552b 100644 --- a/docutranslate/static/index.html +++ b/docutranslate/static/index.html @@ -1 +1 @@ -
GitHub主页(欢迎star❤):
https://github.com/xunbu/docutranslate
交流QQ群: 1047781902
version:{{ version ? 'v' + version : '' }}
{{ t('noTaskPlaceholder') }}
{{ task.backendId || t('taskCardIdPlaceholder') }}
{{ t('taskCardFileDrop') }}
{{ t('taskCardFileSelected') }}
GitHub主页(欢迎star❤):
https://github.com/xunbu/docutranslate
交流QQ群: 1047781902
version:{{ version ? 'v' + version : '' }}
{{ t('noTaskPlaceholder') }}
{{ task.backendId || t('taskCardIdPlaceholder') }}
{{ t('taskCardFileDrop') }}
{{ t('taskCardFileSelected') }}