From 29fb97d338bf9a229a4c286c8668deda39bd8a43 Mon Sep 17 00:00:00 2001 From: xunbu Date: Sat, 13 Sep 2025 11:02:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=B0=E8=BE=BE=E4=B8=8A=E9=99=90=E6=97=B6?= =?UTF-8?q?=E4=B9=9F=E5=BA=94=E8=AF=A5=E5=A2=9E=E5=8A=A0=E6=9C=AA=E8=A7=A3?= =?UTF-8?q?=E5=86=B3=E9=94=99=E8=AF=AF=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docutranslate/agents/agent.py | 139 ++++++++++++++++++---------------- 1 file changed, 74 insertions(+), 65 deletions(-) diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index b51efca..2675683 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -111,14 +111,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: # 尝试从不同格式获取cached_tokens # 格式1: input_tokens_details.cached_tokens if ( - "input_tokens_details" in usage - and "cached_tokens" in usage["input_tokens_details"] + "input_tokens_details" in usage + and "cached_tokens" in usage["input_tokens_details"] ): cached_tokens = usage["input_tokens_details"]["cached_tokens"] # 格式2: prompt_tokens_details.cached_tokens elif ( - "prompt_tokens_details" in usage - and "cached_tokens" in usage["prompt_tokens_details"] + "prompt_tokens_details" in usage + and "cached_tokens" in usage["prompt_tokens_details"] ): cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] # 格式3: prompt_cache_hit_tokens (直接在usage下) @@ -128,14 +128,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: # 尝试从不同格式获取reasoning_tokens # 格式1: output_tokens_details.reasoning_tokens if ( - "output_tokens_details" in usage - and "reasoning_tokens" in usage["output_tokens_details"] + "output_tokens_details" in usage + and "reasoning_tokens" in usage["output_tokens_details"] ): reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] # 格式2: completion_tokens_details.reasoning_tokens elif ( - "completion_tokens_details" in usage - and "reasoning_tokens" in usage["completion_tokens_details"] + "completion_tokens_details" in usage + and "reasoning_tokens" in usage["completion_tokens_details"] ): reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] @@ -153,11 +153,11 @@ class TokenCounter: self.logger = logger def add( - self, - input_tokens: int, - cached_tokens: int, - output_tokens: int, - reasoning_tokens: int, + self, + input_tokens: int, + cached_tokens: int, + output_tokens: int, + reasoning_tokens: int, ): with self.lock: self.input_tokens += input_tokens @@ -256,7 +256,7 @@ class Agent: data[field_thinking] = val_disable def _prepare_request_data( - self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 + self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 ): if temperature is None: temperature = self.temperature @@ -278,16 +278,16 @@ class Agent: return headers, data async def send_async( - self, - client: httpx.AsyncClient, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, - best_partial_result: dict | None = None, + self, + client: httpx.AsyncClient, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt @@ -325,9 +325,7 @@ class Agent: ) if retry_count > 0: - self.logger.info( - f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。" - ) + self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。") # print(f"result:=============================================================\n{result}\n================\n") return ( @@ -372,6 +370,9 @@ class Agent: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") + # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 + with self.unresolved_error_lock: + self.unresolved_error_count += 1 return ( best_partial_result if best_partial_result @@ -383,6 +384,9 @@ class Agent: ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") + # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 + with self.unresolved_error_lock: + self.unresolved_error_count += 1 return ( best_partial_result if best_partial_result @@ -424,13 +428,13 @@ class Agent: ) async def send_prompts_async( - self, - prompts: list[str], - system_prompt: str | None = None, - max_concurrent: int | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + max_concurrent: int | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: max_concurrent = ( self.max_concurrent if max_concurrent is None else max_concurrent @@ -441,7 +445,7 @@ class Agent: ) self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) # 新增:在每次批量发送前重置计数器 @@ -461,8 +465,9 @@ class Agent: ) async with httpx.AsyncClient( - trust_env=False, proxies=proxies, verify=False, limits=limits + trust_env=False, proxies=proxies, verify=False, limits=limits ) as client: + async def send_with_semaphore(p_text: str): async with semaphore: result = await self.send_async( @@ -500,16 +505,16 @@ class Agent: return results def send( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - pre_send_handler=None, - result_handler=None, - error_result_handler=None, - best_partial_result: dict | None = None, + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler=None, + result_handler=None, + error_result_handler=None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt @@ -546,9 +551,7 @@ class Agent: ) if retry_count > 0: - self.logger.info( - f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。" - ) + self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。") return ( result @@ -590,6 +593,9 @@ class Agent: if retry_count == 0: if self.total_error_counter.add(): self.logger.error("错误次数过多,已达到上限,不再重试。") + # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 + with self.unresolved_error_lock: + self.unresolved_error_count += 1 return ( best_partial_result if best_partial_result @@ -601,6 +607,9 @@ class Agent: ) elif self.total_error_counter.reach_limit(): self.logger.error("错误次数过多,已达到上限,不再为该请求重试。") + # 新增:当因为达到错误上限而不再重试时,增加未解决错误计数 + with self.unresolved_error_lock: + self.unresolved_error_count += 1 return ( best_partial_result if best_partial_result @@ -642,14 +651,14 @@ class Agent: ) def _send_prompt_count( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str, - count: PromptsCounter, - pre_send_handler, - result_handler, - error_result_handler, + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str, + count: PromptsCounter, + pre_send_handler, + result_handler, + error_result_handler, ) -> Any: result = self.send( client, @@ -663,12 +672,12 @@ class Agent: return result def send_prompts( - self, - prompts: list[str], - system_prompt: str | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: self.logger.info( f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}" @@ -677,7 +686,7 @@ class Agent: f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" ) self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) # 新增:在每次批量发送前重置计数器 @@ -698,7 +707,7 @@ class Agent: ) proxies = get_httpx_proxies() if USE_PROXY else None with httpx.Client( - trust_env=False, proxies=proxies, verify=False, limits=limits + trust_env=False, proxies=proxies, verify=False, limits=limits ) as client: clients = itertools.repeat(client, len(prompts)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: