diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index ec09e0c..64914e9 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -1,10 +1,11 @@ import asyncio +import itertools import logging import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from threading import Lock -from typing import Literal +from typing import Literal, Callable, Any from urllib.parse import urlparse import httpx @@ -67,6 +68,9 @@ class PromptsCounter: TIMEOUT = 600 +ResultHandlerType = Callable[[str, str, logging.Logger], str] +ErrorResultHandlerType = Callable[[str, logging.Logger], str] + class Agent: _think_factory = { @@ -129,7 +133,9 @@ class Agent: self._add_thinking_mode(data) return headers, data - async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str: + async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None) -> Any: if system_prompt is None: system_prompt = self.system_prompt if prompt.strip() == "": @@ -145,12 +151,12 @@ class Agent: ) response.raise_for_status() result = response.json()["choices"][0]["message"]["content"] - return result + return result if result_handler is None else result_handler(result, prompt, self.logger) except httpx.HTTPStatusError as e: self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") print(f"prompt:\n{prompt}") self.total_error_counter.add() - return prompt + return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) except httpx.RequestError as e: self.logger.warning(f"AI请求连接错误 (async): {repr(e)}") except (KeyError, IndexError) as e: @@ -158,20 +164,23 @@ class Agent: # 如果没有正常获取结果则重试 if retry and retry_count < MAX_RETRY_COUNT: if self.total_error_counter.add(): - return prompt + return prompt if 
error_result_handler is None else error_result_handler(prompt, self.logger) self.logger.info(f"正在重试,重试次数{retry_count}") await asyncio.sleep(0.5) - return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1) + return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1, + result_handler=result_handler) else: self.logger.error(f"达到重试次数上限") - return prompt + return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) async def send_prompts_async( self, prompts: list[str], system_prompt: str | None = None, - max_concurrent: int | None = None # 新增参数,默认并发数为5 - ) -> list[str]: + max_concurrent: int | None = None, # 新增参数,默认并发数为5 + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None + ) -> list[Any]: max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent total = len(prompts) self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") @@ -186,6 +195,8 @@ class Agent: result = await self.send_async( prompt=p_text, system_prompt=system_prompt, + result_handler=result_handler, + error_result_handler=error_result_handler, ) nonlocal count count += 1 @@ -199,7 +210,8 @@ class Agent: results = await asyncio.gather(*tasks, return_exceptions=False) return results - def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str: + def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, + result_handler=None, error_result_handler=None) -> Any: if system_prompt is None: system_prompt = self.system_prompt if prompt.strip() == "": @@ -214,12 +226,12 @@ class Agent: ) response.raise_for_status() result = response.json()["choices"][0]["message"]["content"] - return result + return result if result_handler is None else result_handler(result, prompt, self.logger) except httpx.HTTPStatusError as e: self.logger.warning(f"AI请求错误 (async): 
{e.response.status_code} - {e.response.text}") print(f"prompt:\n{prompt}") self.total_error_counter.add() - return prompt + return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) except httpx.RequestError as e: self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}") except (KeyError, IndexError) as e: @@ -227,16 +239,19 @@ class Agent: # 如果没有正常获取结果则重试 if retry and retry_count < MAX_RETRY_COUNT: if self.total_error_counter.add(): - return prompt + return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) self.logger.info(f"正在重试,重试次数{retry_count}") time.sleep(0.5) - return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1) + return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1, + result_handler=result_handler) else: self.logger.error(f"达到重试次数上限") - return prompt + return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) - def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter) -> str: - result = self.send(prompt, system_prompt) + def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler, + error_result_handler) -> Any: + result = self.send(prompt, system_prompt, result_handler=result_handler, + error_result_handler=error_result_handler) count.add() return result @@ -244,14 +259,23 @@ class Agent: self, prompts: list[str], system_prompt: str | None = None, - ) -> list[str]: + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None + ) -> list[Any]: self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}") - system_prompts = [system_prompt] * len(prompts) - counts = [PromptsCounter(len(prompts), self.logger)] * len(prompts) + + # 创建单个计数器实例 + counter = PromptsCounter(len(prompts), self.logger) + + 
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用 + system_prompts = itertools.repeat(system_prompt, len(prompts)) + counters = itertools.repeat(counter, len(prompts)) + result_handlers = itertools.repeat(result_handler, len(prompts)) + error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) output_list = [] with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: - results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counts) + results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers) output_list = list(results_iterator) return output_list diff --git a/docutranslate/agents/segments_agent.py b/docutranslate/agents/segments_agent.py index c4e0682..de09569 100644 --- a/docutranslate/agents/segments_agent.py +++ b/docutranslate/agents/segments_agent.py @@ -2,11 +2,12 @@ import asyncio import json from dataclasses import dataclass from json import JSONDecodeError +from logging import Logger from json_repair import json_repair from docutranslate.agents import AgentConfig, Agent -from docutranslate.utils.json_utils import segments2json_chunks +from docutranslate.utils.json_utils import segments2json_chunks @dataclass @@ -43,23 +44,38 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string. 
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
    """Parse a model reply into a dict, falling back to the original prompt.

    Args:
        result: Raw text returned by the model.
        origin_prompt: The prompt that produced ``result``; used as the
            fallback source when ``result`` is unusable.
        logger: Logger that receives the parse-failure message.

    Returns:
        The parsed dict on success; otherwise whatever
        ``_error_result_handler(origin_prompt, logger)`` returns.
    """
    try:
        parsed = json_repair.loads(result)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt /
        # SystemExit still propagate.
        logger.error("结果不能正确解析")
        return self._error_result_handler(origin_prompt, logger)
    if not isinstance(parsed, dict):
        # The parser may hand back a non-dict value (e.g. a list or a string);
        # the callers in this file iterate ``.items()``, so treat that exactly
        # like a parse failure instead of letting it blow up downstream.
        logger.error("结果不能正确解析")
        return self._error_result_handler(origin_prompt, logger)
    return parsed

def _error_result_handler(self, origin_prompt: str, logger: Logger):
    """Return the original prompt parsed as JSON, or the raw prompt on failure.

    Args:
        origin_prompt: The prompt text that was sent to the model.
        logger: Logger that receives the failure message.

    Returns:
        The value produced by ``json_repair.loads(origin_prompt)``; if parsing
        raises, the unmodified ``origin_prompt`` string.
    """
    try:
        return json_repair.loads(origin_prompt)
    except Exception:
        # Narrowed from a bare ``except`` so interpreter-exit signals are not swallowed.
        logger.error("prompt不是json格式")
        return origin_prompt
# todo:增加协程粒度 async def send_segments_async(self, segments: list[str], chunk_size: int): - indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks,segments, chunk_size) + indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments, + chunk_size) prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] - translated_chunks = await super().send_prompts_async(prompts=prompts) + translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler, + error_result_handler=self._error_result_handler) indexed_translated = indexed_originals.copy() - for chunk_str in translated_chunks: + for chunk in translated_chunks: try: - translated_part = json_repair.loads(chunk_str) - for key, val in translated_part.items(): + for key, val in chunk.items(): if key in indexed_translated: indexed_translated[key] = val except JSONDecodeError as e: - self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}") + self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}") except ValueError as e: self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}") except Exception as e: - self.logger.info(f"send_segments错误:{e.__repr__()}") - + self.logger.info(f"send_segments发生错误:{e.__repr__()}") # 初始化结果列表 result = [] last_end = 0 - ls=list(indexed_translated.values()) + ls = list(indexed_translated.values()) for start, end in merged_indices_list: # 添加未处理的部分 result.extend(ls[last_end:start])