diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 64914e9..1d59082 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -66,7 +66,6 @@ class PromptsCounter: self.lock.release() -TIMEOUT = 600 ResultHandlerType = Callable[[str, str, logging.Logger], str] ErrorResultHandlerType = Callable[[str, logging.Logger], str] @@ -92,13 +91,6 @@ class Agent: self.model_id = config.model_id.strip() self.system_prompt = config.system_prompt or "" self.temperature = config.temperature - if USE_PROXY: - self.client = httpx.Client(proxies=get_httpx_proxies(), verify=False) - self.client_async = httpx.AsyncClient(proxies=get_httpx_proxies(), verify=False) - else: - self.client = httpx.Client(trust_env=False, proxy=None, verify=False) - self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False) - self.max_concurrent = config.max_concurrent self.timeout = config.timeout self.thinking = config.thinking @@ -133,17 +125,18 @@ class Agent: self._add_thinking_mode(data) return headers, data - async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, + async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True, + retry_count=0, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None) -> Any: if system_prompt is None: system_prompt = self.system_prompt - if prompt.strip() == "": - return prompt + # if prompt.strip() == "": + # return prompt headers, data = self._prepare_request_data(prompt, system_prompt) try: - response = await self.client_async.post( + response = await client.post( f"{self.baseurl}/chat/completions", json=data, headers=headers, @@ -167,7 +160,7 @@ class Agent: return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) self.logger.info(f"正在重试,重试次数{retry_count}") await asyncio.sleep(0.5) - return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1, + return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, result_handler=result_handler) else: self.logger.error(f"达到重试次数上限") @@ -189,36 +182,40 @@ class Agent: semaphore = asyncio.Semaphore(max_concurrent) tasks = [] + proxies = get_httpx_proxies() if USE_PROXY else None + # 辅助协程,用于包装 self.send_async 并使用信号量 - async def send_with_semaphore(p_text: str): - async with semaphore: # 在进入代码块前获取信号量,退出时释放 - result = await self.send_async( - prompt=p_text, - system_prompt=system_prompt, - result_handler=result_handler, - error_result_handler=error_result_handler, - ) - nonlocal count - count += 1 - self.logger.info(f"协程-已完成{count}/{total}") - return result + async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client: + async def send_with_semaphore(p_text: str): + async with semaphore: # 在进入代码块前获取信号量,退出时释放 + result = await self.send_async( + client=client, + prompt=p_text, + system_prompt=system_prompt, + result_handler=result_handler, + error_result_handler=error_result_handler, + ) + nonlocal count + count += 1 + self.logger.info(f"协程-已完成{count}/{total}") + return result - for p_text in prompts: - task = asyncio.create_task(send_with_semaphore(p_text)) - tasks.append(task) + for p_text in prompts: + task = asyncio.create_task(send_with_semaphore(p_text)) + tasks.append(task) - results = await asyncio.gather(*tasks, return_exceptions=False) - return results + results = await asyncio.gather(*tasks, return_exceptions=False) + return results - def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, + def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, result_handler=None, error_result_handler=None) -> Any: if system_prompt is None: system_prompt = self.system_prompt - if prompt.strip() == "": - return prompt + # if prompt.strip() == "": + # return prompt headers, data = self._prepare_request_data(prompt, system_prompt) try: - response = self.client.post( + response = client.post( f"{self.baseurl}/chat/completions", json=data, headers=headers, @@ -228,7 +225,7 @@ class Agent: result = response.json()["choices"][0]["message"]["content"] return result if result_handler is None else result_handler(result, prompt, self.logger) except httpx.HTTPStatusError as e: - self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") + self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}") print(f"prompt:\n{prompt}") self.total_error_counter.add() return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) @@ -242,15 +239,16 @@ class Agent: return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) self.logger.info(f"正在重试,重试次数{retry_count}") time.sleep(0.5) - return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1, + return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, result_handler=result_handler) else: self.logger.error(f"达到重试次数上限") return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) - def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler, + def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter, + result_handler, error_result_handler) -> Any: - result = self.send(prompt, system_prompt, result_handler=result_handler, + result = self.send(client, prompt, system_prompt, result_handler=result_handler, error_result_handler=error_result_handler) count.add() return result @@ -274,9 +272,14 @@ class Agent: result_handlers = itertools.repeat(result_handler, len(prompts)) error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) output_list = [] - with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: - results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers) - output_list = list(results_iterator) + proxies = get_httpx_proxies() if USE_PROXY else None + with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client: + clients = itertools.repeat(client, len(prompts)) + with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: + results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters, + result_handlers, + error_result_handlers) + output_list = list(results_iterator) return output_list