diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 8e4b03a..d8389da 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -82,6 +82,113 @@ class PromptsCounter: self.logger.info(f"多线程-已完成:{self.count}/{self.total}") +def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: + """ + 从API响应中提取token信息 + + 支持多种response格式: + 1. 格式1: usage.input_tokens_details.cached_tokens 和 usage.output_tokens_details.reasoning_tokens + 2. 格式2: usage.prompt_tokens_details.cached_tokens + 3. 格式3: usage.prompt_cache_hit_tokens 和 usage.completion_tokens_details.reasoning_tokens + + Args: + response_data: API响应数据 + + Returns: + tuple: (input_tokens, cached_tokens, output_tokens, reasoning_tokens) + """ + if "usage" not in response_data: + return 0, 0, 0, 0 + + usage = response_data["usage"] + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + + # 初始化token详细统计 + cached_tokens = 0 + reasoning_tokens = 0 + + # 尝试从不同格式获取cached_tokens + # 格式1: input_tokens_details.cached_tokens + if ( + "input_tokens_details" in usage + and "cached_tokens" in usage["input_tokens_details"] + ): + cached_tokens = usage["input_tokens_details"]["cached_tokens"] + # 格式2: prompt_tokens_details.cached_tokens + elif ( + "prompt_tokens_details" in usage + and "cached_tokens" in usage["prompt_tokens_details"] + ): + cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] + # 格式3: prompt_cache_hit_tokens (直接在usage下) + elif "prompt_cache_hit_tokens" in usage: + cached_tokens = usage["prompt_cache_hit_tokens"] + + # 尝试从不同格式获取reasoning_tokens + # 格式1: output_tokens_details.reasoning_tokens + if ( + "output_tokens_details" in usage + and "reasoning_tokens" in usage["output_tokens_details"] + ): + reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] + # 格式2: completion_tokens_details.reasoning_tokens + elif ( + "completion_tokens_details" in usage + and "reasoning_tokens" in usage["completion_tokens_details"] + ): + reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] + + return input_tokens, cached_tokens, output_tokens, reasoning_tokens + + +class TokenCounter: + def __init__(self, logger: logging.Logger): + self.lock = Lock() + self.input_tokens = 0 + self.cached_tokens = 0 + self.output_tokens = 0 + self.reasoning_tokens = 0 + self.total_tokens = 0 + self.logger = logger + + def add( + self, + input_tokens: int, + cached_tokens: int, + output_tokens: int, + reasoning_tokens: int, + ): + with self.lock: + self.input_tokens += input_tokens + self.cached_tokens += cached_tokens + self.output_tokens += output_tokens + self.reasoning_tokens += reasoning_tokens + self.total_tokens += input_tokens + output_tokens + # self.logger.debug( + # f"Token使用统计 - 输入: {self.input_tokens}(含cached: {self.cached_tokens}), " + # f"输出: {self.output_tokens}(含reasoning: {self.reasoning_tokens}), 总计: {self.total_tokens}" + # ) + + def get_stats(self): + with self.lock: + return { + "input_tokens": self.input_tokens, + "cached_tokens": self.cached_tokens, + "output_tokens": self.output_tokens, + "reasoning_tokens": self.reasoning_tokens, + "total_tokens": self.total_tokens, + } + + def reset(self): + with self.lock: + self.input_tokens = 0 + self.cached_tokens = 0 + self.output_tokens = 0 + self.reasoning_tokens = 0 + self.total_tokens = 0 + + PreSendHandlerType = Callable[[str, str], tuple[str, str]] ResultHandlerType = Callable[[str, str, logging.Logger], Any] ErrorResultHandlerType = Callable[[str, logging.Logger], Any] @@ -90,22 +197,30 @@ ErrorResultHandlerType = Callable[[str, logging.Logger], Any] class Agent: _think_factory = { "open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}), - "dashscope.aliyuncs.com": ("enable_thinking ", True, False), - "ark.cn-beijing.volces.com": ("thinking", {"type": "enabled"}, {"type": "disabled"}), - "generativelanguage.googleapis.com": ("extra_body", - {"google": { - "thinking_config": { - "thinking_budget": -1, - "include_thoughts": True - } - } - }, {"google": { - "thinking_config": { - "thinking_budget": 0, - "include_thoughts": False + "dashscope.aliyuncs.com": ( + "extra_body", + {"enable_thinking": True}, + {"enable_thinking": False}, + ), + "ark.cn-beijing.volces.com": ( + "thinking", + {"type": "enabled"}, + {"type": "disabled"}, + ), + "generativelanguage.googleapis.com": ( + "extra_body", + { + "google": { + "thinking_config": {"thinking_budget": -1, "include_thoughts": True} } - }}), - "api.siliconflow.cn": ("enable_thinking", True, False) + }, + { + "google": { + "thinking_config": {"thinking_budget": 0, "include_thoughts": False} + } + }, + ), + "api.siliconflow.cn": ("enable_thinking", True, False), } def __init__(self, config: AgentConfig): @@ -126,6 +241,8 @@ class Agent: # 新增:用于统计最终未解决的错误 self.unresolved_error_lock = Lock() self.unresolved_error_count = 0 + # 新增:用于统计token使用情况 + self.token_counter = TokenCounter(logger=self.logger) def _add_thinking_mode(self, data: dict): if self.domain not in self._think_factory: @@ -136,16 +253,20 @@ class Agent: elif self.thinking == "disable": data[field_thinking] = val_disable - def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): + def _prepare_request_data( + self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 + ): if temperature is None: temperature = self.temperature - headers = {"Content-Type": "application/json", - "Authorization": f"Bearer {self.key}"} + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.key}", + } data = { "model": self.model_id, "messages": [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], "temperature": temperature, "top_p": top_p, @@ -154,12 +275,18 @@ class Agent: self._add_thinking_mode(data) return headers, data - async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True, - retry_count=0, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, - best_partial_result: dict | None = None) -> Any: + async def send_async( + self, + client: httpx.AsyncClient, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, + best_partial_result: dict | None = None, + ) -> Any: if system_prompt is None: system_prompt = self.system_prompt if pre_send_handler: @@ -170,21 +297,42 @@ class Agent: should_retry = False is_hard_error = False # 新增标志,用于区分是否为硬错误 current_partial_result = None + input_tokens = 0 + output_tokens = 0 try: response = await client.post( f"{self.baseurl}/chat/completions", json=data, headers=headers, - timeout=self.timeout + timeout=self.timeout, ) response.raise_for_status() + # print(f"【测试】resp:\n{response.json()}") result = response.json()["choices"][0]["message"]["content"] + + # 获取token使用情况 + response_data = response.json() + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( + extract_token_info(response_data) + ) + + # 更新token计数器 + self.token_counter.add( + input_tokens, cached_tokens, output_tokens, reasoning_tokens + ) + if retry_count > 0: - self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。") + self.logger.info( + f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。" + ) # print(f"result:=============================================================\n{result}\n================\n") - return result if result_handler is None else result_handler(result, prompt, self.logger) + return ( + result + if result_handler is None + else result_handler(result, prompt, self.logger) + ) except AgentResultError as e: self.logger.error(f"AI返回结果有误: {e}") @@ -199,7 +347,9 @@ class Agent: # 捕获硬错误 except httpx.HTTPStatusError as e: - self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}") + self.logger.error( + f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}" + ) should_retry = True is_hard_error = True except httpx.RequestError as e: @@ -220,20 +370,40 @@ class Agent: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") - return best_partial_result if best_partial_result else ( - prompt if error_result_handler is None else error_result_handler(prompt, self.logger)) + return ( + best_partial_result + if best_partial_result + else ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") - return best_partial_result if best_partial_result else ( - prompt if error_result_handler is None else error_result_handler(prompt, self.logger)) + return ( + best_partial_result + if best_partial_result + else ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + ) self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...") await asyncio.sleep(0.5) - return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, - pre_send_handler=pre_send_handler, - result_handler=result_handler, - error_result_handler=error_result_handler, - best_partial_result=best_partial_result) + return await self.send_async( + client, + prompt, + system_prompt, + retry=True, + retry_count=retry_count + 1, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + best_partial_result=best_partial_result, + ) else: if should_retry: self.logger.error(f"所有重试均失败,已达到重试次数上限。") @@ -245,26 +415,37 @@ class Agent: self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。") return best_partial_result - return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) + return ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) async def send_prompts_async( - self, - prompts: list[str], - system_prompt: str | None = None, - max_concurrent: int | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None + self, + prompts: list[str], + system_prompt: str | None = None, + max_concurrent: int | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: - max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent + max_concurrent = ( + self.max_concurrent if max_concurrent is None else max_concurrent + ) total = len(prompts) self.logger.info( - f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}") + f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}" + ) self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") - self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR + self.total_error_counter.max_errors_count = ( + len(prompts) // MAX_REQUESTS_PER_ERROR + ) # 新增:在每次批量发送前重置计数器 self.unresolved_error_count = 0 + # 重置token计数器 + self.token_counter.reset() count = 0 semaphore = asyncio.Semaphore(max_concurrent) @@ -274,10 +455,13 @@ class Agent: limits = httpx.Limits( max_connections=self.max_concurrent * 2, # 为重试和并发预留空间 - max_keepalive_connections=self.max_concurrent # 保持活动的连接数 + max_keepalive_connections=self.max_concurrent, # 保持活动的连接数 ) - async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False, limits=limits) as client: + async with httpx.AsyncClient( + trust_env=False, proxies=proxies, verify=False, limits=limits + ) as client: + async def send_with_semaphore(p_text: str): async with semaphore: result = await self.send_async( @@ -300,13 +484,32 @@ class Agent: results = await asyncio.gather(*tasks, return_exceptions=False) # 新增:在所有任务完成后打印未解决的错误总数 - self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}") + self.logger.info( + f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}" + ) + + # 新增:打印token使用统计 + token_stats = self.token_counter.get_stats() + self.logger.info( + f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " + f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " + f"总计: {token_stats['total_tokens']/1000:.2f}K" + ) return results - def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, - pre_send_handler=None, result_handler=None, error_result_handler=None, - best_partial_result: dict | None = None) -> Any: + def send( + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler=None, + result_handler=None, + error_result_handler=None, + best_partial_result: dict | None = None, + ) -> Any: if system_prompt is None: system_prompt = self.system_prompt if pre_send_handler: @@ -316,21 +519,41 @@ class Agent: should_retry = False is_hard_error = False # 新增标志,用于区分是否为硬错误 current_partial_result = None + input_tokens = 0 + output_tokens = 0 try: response = client.post( f"{self.baseurl}/chat/completions", json=data, headers=headers, - timeout=self.timeout + timeout=self.timeout, ) response.raise_for_status() + result = response.json()["choices"][0]["message"]["content"] - if retry_count > 0: - self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。") + # 获取token使用情况 + response_data = response.json() + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( + extract_token_info(response_data) + ) - return result if result_handler is None else result_handler(result, prompt, self.logger) + # 更新token计数器 + self.token_counter.add( + input_tokens, cached_tokens, output_tokens, reasoning_tokens + ) + + if retry_count > 0: + self.logger.info( + f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。" + ) + + return ( + result + if result_handler is None + else result_handler(result, prompt, self.logger) + ) except AgentResultError as e: self.logger.error(f"AI返回结果有误: {e}") should_retry = True @@ -343,7 +566,9 @@ class Agent: # 捕获硬错误 except httpx.HTTPStatusError as e: - self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}") + self.logger.error( + f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}" + ) should_retry = True is_hard_error = True except httpx.RequestError as e: @@ -364,20 +589,40 @@ class Agent: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") - return best_partial_result if best_partial_result else ( - prompt if error_result_handler is None else error_result_handler(prompt, self.logger)) + return ( + best_partial_result + if best_partial_result + else ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") - return best_partial_result if best_partial_result else ( - prompt if error_result_handler is None else error_result_handler(prompt, self.logger)) + return ( + best_partial_result + if best_partial_result + else ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + ) self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...") time.sleep(0.5) - return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, - pre_send_handler=pre_send_handler, - result_handler=result_handler, - error_result_handler=error_result_handler, - best_partial_result=best_partial_result) + return self.send( + client, + prompt, + system_prompt, + retry=True, + retry_count=retry_count + 1, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + best_partial_result=best_partial_result, + ) else: if should_retry: self.logger.error(f"所有重试均失败,已达到重试次数上限。") @@ -389,33 +634,55 @@ class Agent: self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。") return best_partial_result - return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) + return ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) - def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter, - pre_send_handler, - result_handler, - error_result_handler) -> Any: - result = self.send(client, prompt, system_prompt, pre_send_handler=pre_send_handler, - result_handler=result_handler, - error_result_handler=error_result_handler) + def _send_prompt_count( + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str, + count: PromptsCounter, + pre_send_handler, + result_handler, + error_result_handler, + ) -> Any: + result = self.send( + client, + prompt, + system_prompt, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + ) count.add() return result def send_prompts( - self, - prompts: list[str], - system_prompt: str | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None + self, + prompts: list[str], + system_prompt: str | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: self.logger.info( - f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}") - self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}") - self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR + f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}" + ) + self.logger.info( + f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" + ) + self.total_error_counter.max_errors_count = ( + len(prompts) // MAX_REQUESTS_PER_ERROR + ) # 新增:在每次批量发送前重置计数器 self.unresolved_error_count = 0 + # 重置token计数器 + self.token_counter.reset() counter = PromptsCounter(len(prompts), self.logger) @@ -426,23 +693,41 @@ class Agent: error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) limits = httpx.Limits( max_connections=self.max_concurrent * 2, # 允许连接复用 - max_keepalive_connections=self.max_concurrent # 保持活跃连接 + max_keepalive_connections=self.max_concurrent, # 保持活跃连接 ) proxies = get_httpx_proxies() if USE_PROXY else None - with httpx.Client(trust_env=False, proxies=proxies, verify=False, limits=limits) as client: + with httpx.Client( + trust_env=False, proxies=proxies, verify=False, limits=limits + ) as client: clients = itertools.repeat(client, len(prompts)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: - results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters, - pre_send_handlers, - result_handlers, - error_result_handlers) + results_iterator = executor.map( + self._send_prompt_count, + clients, + prompts, + system_prompts, + counters, + pre_send_handlers, + result_handlers, + error_result_handlers, + ) output_list = list(results_iterator) # 新增:在所有任务完成后打印未解决的错误总数 - self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}") + self.logger.info( + f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}" + ) + + # 新增:打印token使用统计 + token_stats = self.token_counter.get_stats() + self.logger.info( + f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " + f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " + f"总计: {token_stats['total_tokens']/1000:.2f}K" + ) return output_list -if __name__ == '__main__': +if __name__ == "__main__": pass