From 190ba0143028e7335b98aded3801854c5b41c8c8 Mon Sep 17 00:00:00 2001 From: xunbu Date: Fri, 1 Aug 2025 13:19:48 +0800 Subject: [PATCH] =?UTF-8?q?agent=E9=85=8D=E7=BD=AE=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docutranslate/agents/__init__.py | 4 +- docutranslate/agents/agent.py | 32 +++++----- docutranslate/agents/markdown_agent.py | 59 ++++--------------- docutranslate/agents/txt_agent.py | 20 +++---- .../translator/ai_translator/md_translator.py | 23 ++++---- .../ai_translator/txt_translator.py | 23 ++++---- 6 files changed, 62 insertions(+), 99 deletions(-) diff --git a/docutranslate/agents/__init__.py b/docutranslate/agents/__init__.py index d918a04..38147a1 100644 --- a/docutranslate/agents/__init__.py +++ b/docutranslate/agents/__init__.py @@ -1,2 +1,2 @@ -from .agent import Agent, AgentArgs -from .markdown_agent import MDRefineAgent, MDTranslateAgent +from .agent import Agent, AgentConfig +from .markdown_agent import MDTranslateAgent diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index c1ff2b9..02aba40 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -2,8 +2,8 @@ import asyncio import logging import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from threading import Lock -from typing import TypedDict import httpx @@ -13,15 +13,16 @@ MAX_RETRY_COUNT = 2 MAX_TOTAL_ERROR_COUNT = 10 -class AgentArgs(TypedDict, total=False): +@dataclass(kw_only=True) +class AgentConfig: + logger: logging.Logger baseurl: str key: str model_id: str system_prompt: str | None - temperature: float - max_concurrent: int - timeout: int - logger: logging.Logger + temperature: float = 0.7 + max_concurrent: int = 30 + timeout: int = 2000 class TotalErrorCounter: @@ -61,21 +62,20 @@ TIMEOUT = 600 class Agent: - def __init__(self, baseurl: str, key: str | None, model_id: str, system_prompt: str | None = None, temperature=0.7, - max_concurrent=15, timeout: int = TIMEOUT, logger: logging.Logger | None = None): - self.baseurl = baseurl.strip() + def __init__(self, config: AgentConfig): + self.baseurl = config.baseurl.strip() if self.baseurl.endswith("/"): self.baseurl = self.baseurl[:-1] - self.key = key.strip() or "xx" - self.model_id = model_id.strip() - self.system_prompt = system_prompt or "" - self.temperature = temperature + self.key = config.key.strip() or "xx" + self.model_id = config.model_id.strip() + self.system_prompt = config.system_prompt or "" + self.temperature = config.temperature self.client = httpx.Client(trust_env=False, proxy=None, verify=False) self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False) - self.max_concurrent = max_concurrent - self.timeout = timeout + self.max_concurrent = config.max_concurrent + self.timeout = config.timeout - self.logger = logger if logger else global_logger + self.logger = config.logger or global_logger self.total_error_counter = TotalErrorCounter(logger=self.logger) def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): diff --git a/docutranslate/agents/markdown_agent.py b/docutranslate/agents/markdown_agent.py index 4aac3be..5a885c4 100644 --- a/docutranslate/agents/markdown_agent.py +++ b/docutranslate/agents/markdown_agent.py @@ -1,60 +1,21 @@ -from typing import Unpack, NotRequired +from dataclasses import dataclass -from .agent import Agent, AgentArgs +from .agent import Agent, AgentConfig -class MDTranslateAgentArgs(AgentArgs, total=True): +@dataclass +class MDTranslateAgentConfig(AgentConfig): to_lang:str - custom_prompt:NotRequired[str] - -class MDRefineAgent(Agent): - def __init__(self, custom_prompt=None, **kwargs: Unpack[AgentArgs]): - super().__init__(**kwargs) - self.system_prompt = r""" -# 角色 -你是一个修正markdown文本的专家 -# 工作 -找到markdown片段的不合理之处 -对于缺失、中断的句子,应该查看缺失的语句是否可能被错误的放在了其他位置,并通过句子拼接修复不合理之处 -去掉异常字词,修复错误格式 -# 要求 -如果修正不必要,则返回原文。 -不要解释,不要注释。 -不要修改标题的级别(如一级标题不要修改为二级标题) -形如的占位符不要改变【非常重要】 -code、latex和HTML保持结构 -所有公式(包括短公式)都应该是latex公式 -公式无论长短必须表示为能被解析的合法latex公式,公式需被$或\\(\\)或$$正确包裹,如不正确则进行修正 -# 输出 -修正后的markdown纯文本(不是markdown代码块) -# 示例 -## 修正文本流 -输入: -什么名字 -你叫 -输出: -你叫什么名字 -## 去掉异常字词与修正公式(行内公式使用$包裹) -输入: -一道\题@#目:c_0+1=2,\(c 0\)等于几 -{c_0,c_1,c^2}是一个集合 -输出: -一道题目:$c_0+1=2$,$c_0$等于几 -{$c_0$,$c_1$,$c^2$}是一个集合""" - if custom_prompt: - self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n' - self.system_prompt += r'\no_think' - + custom_prompt:str|None=None class MDTranslateAgent(Agent): - def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]): - print(f"custom_prompt:{custom_prompt}") - super().__init__(**kwargs) + def __init__(self,config:MDTranslateAgentConfig): + super().__init__(config) self.system_prompt = f""" # 角色 你是一个专业的机器翻译引擎 # 工作 翻译输入的markdown文本 -目标语言{to_lang} +目标语言{config.to_lang} # 要求 翻译要求专业准确 不输出任何解释和注释 @@ -81,6 +42,6 @@ The equation is E=mc 2. This is famous. 这个方程是 $E=mc^2$。这很有名。 $$1+1=2$$ \\((c_0,c_1,c_2^2)\\)是一个坐标。""" - if custom_prompt: - self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n' + if config.custom_prompt: + self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n' self.system_prompt += r'\no_think' diff --git a/docutranslate/agents/txt_agent.py b/docutranslate/agents/txt_agent.py index 386fd2b..946396c 100644 --- a/docutranslate/agents/txt_agent.py +++ b/docutranslate/agents/txt_agent.py @@ -1,23 +1,23 @@ -from typing import NotRequired, Unpack +from dataclasses import dataclass -from docutranslate.agents import AgentArgs, Agent +from docutranslate.agents import AgentConfig, Agent -class TXTTranslateAgentArgs(AgentArgs, total=True): +@dataclass +class TXTTranslateAgentConfig(AgentConfig): to_lang: str - custom_prompt: NotRequired[str] + custom_prompt: str | None = None class TXTTranslateAgent(Agent): - def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]): - print(f"custom_prompt:{custom_prompt}") - super().__init__(**kwargs) + def __init__(self, config: TXTTranslateAgentConfig): + super().__init__(config) self.system_prompt = f""" # 角色 你是一个专业的机器翻译引擎 # 工作 翻译输入的txt文本 -目标语言{to_lang} +目标语言{config.to_lang} # 要求 翻译要求专业准确 不输出任何解释和注释 @@ -25,6 +25,6 @@ class TXTTranslateAgent(Agent): # 输出 翻译后的txt译文纯文本 """ - if custom_prompt: - self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n' + if config.custom_prompt: + self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n' self.system_prompt += r'\no_think' diff --git a/docutranslate/translator/ai_translator/md_translator.py b/docutranslate/translator/ai_translator/md_translator.py index 80b07fe..7ef6163 100644 --- a/docutranslate/translator/ai_translator/md_translator.py +++ b/docutranslate/translator/ai_translator/md_translator.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Self from docutranslate.agents import MDTranslateAgent +from docutranslate.agents.markdown_agent import MDTranslateAgentConfig from docutranslate.context.md_mask_context import MDMaskUrisContext from docutranslate.ir.markdown_document import MarkdownDocument from docutranslate.translator.ai_translator.base import AiTranslatorConfig @@ -16,21 +17,21 @@ class MDTranslatorConfig(AiTranslatorConfig): ... - class MDTranslator(Translator): def __init__(self, config: MDTranslatorConfig): super().__init__(config=config) self.chunk_size = config.chunk_size - self.translate_agent = MDTranslateAgent(custom_prompt=config.custom_prompt, - to_lang=config.to_lang, - baseurl=config.base_url, - key=config.api_key, - model_id=config.model_id, - system_prompt=None, - temperature=config.temperature, - max_concurrent=config.concurrent, - timeout=config.timeout, - logger=self.logger) + agent_config = MDTranslateAgentConfig(custom_prompt=config.custom_prompt, + to_lang=config.to_lang, + baseurl=config.base_url, + key=config.api_key, + model_id=config.model_id, + system_prompt=None, + temperature=config.temperature, + max_concurrent=config.concurrent, + timeout=config.timeout, + logger=self.logger) + self.translate_agent = MDTranslateAgent(agent_config) def translate(self, document: MarkdownDocument) -> Self: self.logger.info("正在翻译markdown") diff --git a/docutranslate/translator/ai_translator/txt_translator.py b/docutranslate/translator/ai_translator/txt_translator.py index 128bf5f..4ded8c8 100644 --- a/docutranslate/translator/ai_translator/txt_translator.py +++ b/docutranslate/translator/ai_translator/txt_translator.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Self -from docutranslate.agents.txt_agent import TXTTranslateAgent +from docutranslate.agents.txt_agent import TXTTranslateAgent, TXTTranslateAgentConfig from docutranslate.ir.document import Document from docutranslate.translator.ai_translator.base import AiTranslatorConfig from docutranslate.translator.base import Translator @@ -17,16 +17,17 @@ class TXTTranslator(Translator): def __init__(self, config: TXTTranslatorConfig): super().__init__(config=config) self.chunk_size = config.chunk_size - self.translate_agent = TXTTranslateAgent(custom_prompt=config.custom_prompt, - to_lang=config.to_lang, - baseurl=config.base_url, - key=config.api_key, - model_id=config.model_id, - system_prompt=None, - temperature=config.temperature, - max_concurrent=config.concurrent, - timeout=config.timeout, - logger=self.logger) + agent_config = TXTTranslateAgentConfig(custom_prompt=config.custom_prompt, + to_lang=config.to_lang, + baseurl=config.base_url, + key=config.api_key, + model_id=config.model_id, + system_prompt=None, + temperature=config.temperature, + max_concurrent=config.concurrent, + timeout=config.timeout, + logger=self.logger) + self.translate_agent = TXTTranslateAgent(agent_config) def translate(self, document: Document) -> Self: self.logger.info("正在翻译txt")