agent配置化

This commit is contained in:
xunbu
2025-08-01 13:19:48 +08:00
parent 89b1963b97
commit 190ba01430
6 changed files with 62 additions and 99 deletions

View File

@@ -1,2 +1,2 @@
from .agent import Agent, AgentArgs
from .markdown_agent import MDRefineAgent, MDTranslateAgent
from .agent import Agent, AgentConfig
from .markdown_agent import MDTranslateAgent

View File

@@ -2,8 +2,8 @@ import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Lock
from typing import TypedDict
import httpx
@@ -13,15 +13,16 @@ MAX_RETRY_COUNT = 2
MAX_TOTAL_ERROR_COUNT = 10
class AgentArgs(TypedDict, total=False):
@dataclass(kw_only=True)
class AgentConfig:
logger: logging.Logger
baseurl: str
key: str
model_id: str
system_prompt: str | None
temperature: float
max_concurrent: int
timeout: int
logger: logging.Logger
temperature: float = 0.7
max_concurrent: int = 30
timeout: int = 2000
class TotalErrorCounter:
@@ -61,21 +62,20 @@ TIMEOUT = 600
class Agent:
def __init__(self, baseurl: str, key: str | None, model_id: str, system_prompt: str | None = None, temperature=0.7,
max_concurrent=15, timeout: int = TIMEOUT, logger: logging.Logger | None = None):
self.baseurl = baseurl.strip()
def __init__(self, config: AgentConfig):
self.baseurl = config.baseurl.strip()
if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1]
self.key = key.strip() or "xx"
self.model_id = model_id.strip()
self.system_prompt = system_prompt or ""
self.temperature = temperature
self.key = config.key.strip() or "xx"
self.model_id = config.model_id.strip()
self.system_prompt = config.system_prompt or ""
self.temperature = config.temperature
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
self.max_concurrent = max_concurrent
self.timeout = timeout
self.max_concurrent = config.max_concurrent
self.timeout = config.timeout
self.logger = logger if logger else global_logger
self.logger = config.logger or global_logger
self.total_error_counter = TotalErrorCounter(logger=self.logger)
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):

View File

@@ -1,60 +1,21 @@
from typing import Unpack, NotRequired
from dataclasses import dataclass
from .agent import Agent, AgentArgs
from .agent import Agent, AgentConfig
class MDTranslateAgentArgs(AgentArgs, total=True):
@dataclass
class MDTranslateAgentConfig(AgentConfig):
to_lang:str
custom_prompt:NotRequired[str]
class MDRefineAgent(Agent):
def __init__(self, custom_prompt=None, **kwargs: Unpack[AgentArgs]):
super().__init__(**kwargs)
self.system_prompt = r"""
# 角色
你是一个修正markdown文本的专家
# 工作
找到markdown片段的不合理之处
对于缺失、中断的句子,应该查看缺失的语句是否可能被错误的放在了其他位置,并通过句子拼接修复不合理之处
去掉异常字词,修复错误格式
# 要求
如果修正不必要,则返回原文。
不要解释,不要注释。
不要修改标题的级别(如一级标题不要修改为二级标题)
形如<ph-ads231>的占位符不要改变【非常重要】
code、latex和HTML保持结构
所有公式包括短公式都应该是latex公式
公式无论长短必须表示为能被解析的合法latex公式公式需被$或\\(\\)或$$正确包裹,如不正确则进行修正
# 输出
修正后的markdown纯文本不是markdown代码块
# 示例
## 修正文本流
输入:
什么名字
你叫
输出:
你叫什么名字
## 去掉异常字词与修正公式(行内公式使用$包裹)
输入:
一道\题@#目<ph-12asd2>:c_0+1=2\(c 0\)等于几
{c_0,c_1,c^2}是一个集合
输出:
一道题目<ph-12asd2>:$c_0+1=2$$c_0$等于几
{$c_0$,$c_1$,$c^2$}是一个集合"""
if custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
self.system_prompt += r'\no_think'
custom_prompt:str|None=None
class MDTranslateAgent(Agent):
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
print(f"custom_prompt:{custom_prompt}")
super().__init__(**kwargs)
def __init__(self,config:MDTranslateAgentConfig):
super().__init__(config)
self.system_prompt = f"""
# 角色
你是一个专业的机器翻译引擎
# 工作
翻译输入的markdown文本
目标语言{to_lang}
目标语言{config.to_lang}
# 要求
翻译要求专业准确
不输出任何解释和注释
@@ -81,6 +42,6 @@ The equation is E=mc 2. This is famous.
这个方程是 $E=mc^2$。这很有名。
$$1+1=2$$
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
if custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
if config.custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
self.system_prompt += r'\no_think'

View File

@@ -1,23 +1,23 @@
from typing import NotRequired, Unpack
from dataclasses import dataclass
from docutranslate.agents import AgentArgs, Agent
from docutranslate.agents import AgentConfig, Agent
class TXTTranslateAgentArgs(AgentArgs, total=True):
@dataclass
class TXTTranslateAgentConfig(AgentConfig):
to_lang: str
custom_prompt: NotRequired[str]
custom_prompt: str | None = None
class TXTTranslateAgent(Agent):
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
print(f"custom_prompt:{custom_prompt}")
super().__init__(**kwargs)
def __init__(self, config: TXTTranslateAgentConfig):
super().__init__(config)
self.system_prompt = f"""
# 角色
你是一个专业的机器翻译引擎
# 工作
翻译输入的txt文本
目标语言{to_lang}
目标语言{config.to_lang}
# 要求
翻译要求专业准确
不输出任何解释和注释
@@ -25,6 +25,6 @@ class TXTTranslateAgent(Agent):
# 输出
翻译后的txt译文纯文本
"""
if custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
if config.custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
self.system_prompt += r'\no_think'

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Self
from docutranslate.agents import MDTranslateAgent
from docutranslate.agents.markdown_agent import MDTranslateAgentConfig
from docutranslate.context.md_mask_context import MDMaskUrisContext
from docutranslate.ir.markdown_document import MarkdownDocument
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
@@ -16,21 +17,21 @@ class MDTranslatorConfig(AiTranslatorConfig):
...
class MDTranslator(Translator):
def __init__(self, config: MDTranslatorConfig):
super().__init__(config=config)
self.chunk_size = config.chunk_size
self.translate_agent = MDTranslateAgent(custom_prompt=config.custom_prompt,
to_lang=config.to_lang,
baseurl=config.base_url,
key=config.api_key,
model_id=config.model_id,
system_prompt=None,
temperature=config.temperature,
max_concurrent=config.concurrent,
timeout=config.timeout,
logger=self.logger)
agent_config = MDTranslateAgentConfig(custom_prompt=config.custom_prompt,
to_lang=config.to_lang,
baseurl=config.base_url,
key=config.api_key,
model_id=config.model_id,
system_prompt=None,
temperature=config.temperature,
max_concurrent=config.concurrent,
timeout=config.timeout,
logger=self.logger)
self.translate_agent = MDTranslateAgent(agent_config)
def translate(self, document: MarkdownDocument) -> Self:
self.logger.info("正在翻译markdown")

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Self
from docutranslate.agents.txt_agent import TXTTranslateAgent
from docutranslate.agents.txt_agent import TXTTranslateAgent, TXTTranslateAgentConfig
from docutranslate.ir.document import Document
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
from docutranslate.translator.base import Translator
@@ -17,16 +17,17 @@ class TXTTranslator(Translator):
def __init__(self, config: TXTTranslatorConfig):
super().__init__(config=config)
self.chunk_size = config.chunk_size
self.translate_agent = TXTTranslateAgent(custom_prompt=config.custom_prompt,
to_lang=config.to_lang,
baseurl=config.base_url,
key=config.api_key,
model_id=config.model_id,
system_prompt=None,
temperature=config.temperature,
max_concurrent=config.concurrent,
timeout=config.timeout,
logger=self.logger)
agent_config = TXTTranslateAgentConfig(custom_prompt=config.custom_prompt,
to_lang=config.to_lang,
baseurl=config.base_url,
key=config.api_key,
model_id=config.model_id,
system_prompt=None,
temperature=config.temperature,
max_concurrent=config.concurrent,
timeout=config.timeout,
logger=self.logger)
self.translate_agent = TXTTranslateAgent(agent_config)
def translate(self, document: Document) -> Self:
self.logger.info("正在翻译txt")