diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 5059ae7..0003bc7 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -34,11 +34,11 @@ class AgentConfig: class TotalErrorCounter: - def __init__(self, logger: logging.Logger,max_errors_count=10): + def __init__(self, logger: logging.Logger, max_errors_count=10): self.lock = Lock() self.count = 0 self.logger = logger - self.max_errors_count=max_errors_count + self.max_errors_count = max_errors_count def add(self): self.lock.acquire() @@ -67,7 +67,7 @@ class PromptsCounter: self.lock.release() - +PreSendHandlerType = Callable[[str, str], tuple[str, str]] ResultHandlerType = Callable[[str, str, logging.Logger], str] ErrorResultHandlerType = Callable[[str, logging.Logger], str] @@ -128,10 +128,13 @@ class Agent: async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, + pre_send_handler: PreSendHandlerType = None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None) -> Any: if system_prompt is None: system_prompt = self.system_prompt + if pre_send_handler: + system_prompt, prompt = pre_send_handler(system_prompt, prompt) # if prompt.strip() == "": # return prompt headers, data = self._prepare_request_data(prompt, system_prompt) @@ -172,6 +175,7 @@ class Agent: prompts: list[str], system_prompt: str | None = None, max_concurrent: int | None = None, # 新增参数,默认并发数为5 + pre_send_handler: PreSendHandlerType = None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None ) -> list[Any]: @@ -179,7 +183,7 @@ class Agent: total = len(prompts) self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") - self.total_error_counter.max_errors_count=len(prompts) // MAX_REQUESTS_PER_ERROR #允许多少个异常 + self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常 count = 0 semaphore = asyncio.Semaphore(max_concurrent) tasks = [] @@ -194,6 +198,7 @@ class Agent: client=client, prompt=p_text, system_prompt=system_prompt, + pre_send_handler=pre_send_handler, result_handler=result_handler, error_result_handler=error_result_handler, ) @@ -210,9 +215,11 @@ class Agent: return results def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, - result_handler=None, error_result_handler=None) -> Any: + pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any: if system_prompt is None: system_prompt = self.system_prompt + if pre_send_handler: + system_prompt, prompt = pre_send_handler(system_prompt, prompt) # if prompt.strip() == "": # return prompt headers, data = self._prepare_request_data(prompt, system_prompt) @@ -259,6 +266,7 @@ class Agent: self, prompts: list[str], system_prompt: str | None = None, + pre_send_handler:PreSendHandlerType=None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None ) -> list[Any]: @@ -271,6 +279,7 @@ class Agent: # 使用 itertools.repeat 将同一个实例传递给每个 map 调用 system_prompts = itertools.repeat(system_prompt, len(prompts)) counters = itertools.repeat(counter, len(prompts)) + pre_send_handlers=itertools.repeat(pre_send_handler,len(prompts)) result_handlers = itertools.repeat(result_handler, len(prompts)) error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) output_list = [] @@ -279,6 +288,7 @@ class Agent: clients = itertools.repeat(client, len(prompts)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters, + pre_send_handlers, result_handlers, error_result_handlers) output_list = list(results_iterator) diff --git a/docutranslate/agents/markdown_agent.py b/docutranslate/agents/markdown_agent.py index 29ff320..55a12df 100644 --- a/docutranslate/agents/markdown_agent.py +++ b/docutranslate/agents/markdown_agent.py @@ -1,14 +1,18 @@ from dataclasses import dataclass from .agent import Agent, AgentConfig +from ..glossary.glossary import Glossary + @dataclass class MDTranslateAgentConfig(AgentConfig): - to_lang:str - custom_prompt:str|None=None + to_lang: str + custom_prompt: str | None = None + glossary_dict: dict[str, str] | None = None + class MDTranslateAgent(Agent): - def __init__(self,config:MDTranslateAgentConfig): + def __init__(self, config: MDTranslateAgentConfig): super().__init__(config) self.system_prompt = f""" # Role @@ -48,5 +52,17 @@ Output: 这个方程是 $E=mc^2$。这很有名。 $$1+1=2$$ \\((c_0,c_1,c_2^2)\\)是一个坐标。""" + self.custom_prompt = config.custom_prompt if config.custom_prompt: - self.system_prompt += "\n# **Important rules or background** \n" + config.custom_prompt + '\n' + self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n' + self.glossary_dict = config.glossary_dict + + def _pre_send_handler(self, system_prompt, prompt): + if self.glossary_dict: + glossary = Glossary(glossary_dict=self.glossary_dict) + system_prompt += glossary.append_system_prompt(prompt) + return system_prompt, prompt + def send_chunks(self, prompts: list[str]): + return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler) + async def send_chunks_async(self, prompts: list[str]): + return await super().send_prompts_async(prompts=prompts,pre_send_handler=self._pre_send_handler) \ No newline at end of file diff --git a/docutranslate/agents/segments_agent.py b/docutranslate/agents/segments_agent.py index b9aed65..d0eef49 100644 --- a/docutranslate/agents/segments_agent.py +++ b/docutranslate/agents/segments_agent.py @@ -7,6 +7,7 @@ from logging import Logger from json_repair import json_repair from docutranslate.agents import AgentConfig, Agent +from docutranslate.glossary.glossary import Glossary from docutranslate.utils.json_utils import segments2json_chunks @@ -14,6 +15,7 @@ from docutranslate.utils.json_utils import segments2json_chunks class SegmentsTranslateAgentConfig(AgentConfig): to_lang: str custom_prompt: str | None = None + glossary_dict: dict[str, str] | None = None class SegmentsTranslateAgent(Agent): @@ -43,13 +45,21 @@ Output {r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'} Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```. """ + self.custom_prompt = config.custom_prompt if config.custom_prompt: - self.system_prompt += "\n# **Important rules or background** for segments for translation \n" + config.custom_prompt + '\n' + self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n' + self.glossary_dict = config.glossary_dict + + def _pre_send_handler(self, system_prompt, prompt): + if self.glossary_dict: + glossary = Glossary(glossary_dict=self.glossary_dict) + system_prompt += glossary.append_system_prompt(prompt) + return system_prompt, prompt def _result_handler(self, result: str, origin_prompt: str, logger: Logger): try: result = json_repair.loads(result) - if not isinstance(result,dict): + if not isinstance(result, dict): raise ValueError("agent返回结果不是dict的json形式") except: logger.error("结果不能正确解析") @@ -66,7 +76,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string. def send_segments(self, segments: list[str], chunk_size: int): indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] - translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler, + translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, + result_handler=self._result_handler, error_result_handler=self._error_result_handler) indexed_translated = indexed_originals.copy() for chunk in translated_chunks: @@ -102,7 +113,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string. indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments, chunk_size) prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] - translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler, + translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, + result_handler=self._result_handler, error_result_handler=self._error_result_handler) indexed_translated = indexed_originals.copy() for chunk in translated_chunks: diff --git a/docutranslate/agents/txt_agent.py b/docutranslate/agents/txt_agent.py index c599c58..44a5142 100644 --- a/docutranslate/agents/txt_agent.py +++ b/docutranslate/agents/txt_agent.py @@ -1,12 +1,14 @@ from dataclasses import dataclass from docutranslate.agents import AgentConfig, Agent +from docutranslate.glossary.glossary import Glossary @dataclass class TXTTranslateAgentConfig(AgentConfig): to_lang: str custom_prompt: str | None = None + glossary_dict: dict[str, str] | None = None class TXTTranslateAgent(Agent): @@ -30,5 +32,17 @@ Target language: {config.to_lang} # Output The translated txt text as plain text. """ + self.custom_prompt = config.custom_prompt if config.custom_prompt: - self.system_prompt += "\n# **Important rules or background**\n" + config.custom_prompt + '\n' + self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n' + self.glossary_dict = config.glossary_dict + + def _pre_send_handler(self, system_prompt, prompt): + if self.glossary_dict: + glossary = Glossary(glossary_dict=self.glossary_dict) + system_prompt += glossary.append_system_prompt(prompt) + return system_prompt, prompt + def send_chunks(self, prompts: list[str]): + return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler) + async def send_chunks_async(self, prompts: list[str]): + return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler) diff --git a/docutranslate/glossary/__init__.py b/docutranslate/glossary/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docutranslate/glossary/glossary.py b/docutranslate/glossary/glossary.py new file mode 100644 index 0000000..8c9dbc1 --- /dev/null +++ b/docutranslate/glossary/glossary.py @@ -0,0 +1,15 @@ +class Glossary: + def __init__(self,glossary_dict:dict[str:str]|None=None): + self.glossary_dict=glossary_dict + + def update(self,update_dict:dict[str:str]): + for src,dst in update_dict.items(): + if src not in self.glossary_dict: + self.glossary_dict[src]=dst + + def append_system_prompt(self,text:str): + prompt="\n以下为参考术语表:\n" + for src,dst in self.glossary_dict: + if src in text: + prompt+=f"{src}=>{dst}\n" + prompt+="术语表结束\n"