agent配置化

This commit is contained in:
xunbu
2025-08-01 13:19:48 +08:00
parent 89b1963b97
commit 190ba01430
6 changed files with 62 additions and 99 deletions

View File

@@ -1,2 +1,2 @@
from .agent import Agent, AgentArgs from .agent import Agent, AgentConfig
from .markdown_agent import MDRefineAgent, MDTranslateAgent from .markdown_agent import MDTranslateAgent

View File

@@ -2,8 +2,8 @@ import asyncio
import logging import logging
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Lock from threading import Lock
from typing import TypedDict
import httpx import httpx
@@ -13,15 +13,16 @@ MAX_RETRY_COUNT = 2
MAX_TOTAL_ERROR_COUNT = 10 MAX_TOTAL_ERROR_COUNT = 10
class AgentArgs(TypedDict, total=False): @dataclass(kw_only=True)
class AgentConfig:
logger: logging.Logger
baseurl: str baseurl: str
key: str key: str
model_id: str model_id: str
system_prompt: str | None system_prompt: str | None
temperature: float temperature: float = 0.7
max_concurrent: int max_concurrent: int = 30
timeout: int timeout: int = 2000
logger: logging.Logger
class TotalErrorCounter: class TotalErrorCounter:
@@ -61,21 +62,20 @@ TIMEOUT = 600
class Agent: class Agent:
def __init__(self, baseurl: str, key: str | None, model_id: str, system_prompt: str | None = None, temperature=0.7, def __init__(self, config: AgentConfig):
max_concurrent=15, timeout: int = TIMEOUT, logger: logging.Logger | None = None): self.baseurl = config.baseurl.strip()
self.baseurl = baseurl.strip()
if self.baseurl.endswith("/"): if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1] self.baseurl = self.baseurl[:-1]
self.key = key.strip() or "xx" self.key = config.key.strip() or "xx"
self.model_id = model_id.strip() self.model_id = config.model_id.strip()
self.system_prompt = system_prompt or "" self.system_prompt = config.system_prompt or ""
self.temperature = temperature self.temperature = config.temperature
self.client = httpx.Client(trust_env=False, proxy=None, verify=False) self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False) self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
self.max_concurrent = max_concurrent self.max_concurrent = config.max_concurrent
self.timeout = timeout self.timeout = config.timeout
self.logger = logger if logger else global_logger self.logger = config.logger or global_logger
self.total_error_counter = TotalErrorCounter(logger=self.logger) self.total_error_counter = TotalErrorCounter(logger=self.logger)
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):

View File

@@ -1,60 +1,21 @@
from typing import Unpack, NotRequired from dataclasses import dataclass
from .agent import Agent, AgentArgs from .agent import Agent, AgentConfig
class MDTranslateAgentArgs(AgentArgs, total=True): @dataclass
class MDTranslateAgentConfig(AgentConfig):
to_lang:str to_lang:str
custom_prompt:NotRequired[str] custom_prompt:str|None=None
class MDRefineAgent(Agent):
def __init__(self, custom_prompt=None, **kwargs: Unpack[AgentArgs]):
super().__init__(**kwargs)
self.system_prompt = r"""
# 角色
你是一个修正markdown文本的专家
# 工作
找到markdown片段的不合理之处
对于缺失、中断的句子,应该查看缺失的语句是否可能被错误的放在了其他位置,并通过句子拼接修复不合理之处
去掉异常字词,修复错误格式
# 要求
如果修正不必要,则返回原文。
不要解释,不要注释。
不要修改标题的级别(如一级标题不要修改为二级标题)
形如<ph-ads231>的占位符不要改变【非常重要】
code、latex和HTML保持结构
所有公式包括短公式都应该是latex公式
公式无论长短必须表示为能被解析的合法latex公式公式需被$或\\(\\)或$$正确包裹,如不正确则进行修正
# 输出
修正后的markdown纯文本不是markdown代码块
# 示例
## 修正文本流
输入:
什么名字
你叫
输出:
你叫什么名字
## 去掉异常字词与修正公式(行内公式使用$包裹)
输入:
一道\题@#目<ph-12asd2>:c_0+1=2\(c 0\)等于几
{c_0,c_1,c^2}是一个集合
输出:
一道题目<ph-12asd2>:$c_0+1=2$$c_0$等于几
{$c_0$,$c_1$,$c^2$}是一个集合"""
if custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
self.system_prompt += r'\no_think'
class MDTranslateAgent(Agent): class MDTranslateAgent(Agent):
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]): def __init__(self,config:MDTranslateAgentConfig):
print(f"custom_prompt:{custom_prompt}") super().__init__(config)
super().__init__(**kwargs)
self.system_prompt = f""" self.system_prompt = f"""
# 角色 # 角色
你是一个专业的机器翻译引擎 你是一个专业的机器翻译引擎
# 工作 # 工作
翻译输入的markdown文本 翻译输入的markdown文本
目标语言{to_lang} 目标语言{config.to_lang}
# 要求 # 要求
翻译要求专业准确 翻译要求专业准确
不输出任何解释和注释 不输出任何解释和注释
@@ -81,6 +42,6 @@ The equation is E=mc 2. This is famous.
这个方程是 $E=mc^2$。这很有名。 这个方程是 $E=mc^2$。这很有名。
$$1+1=2$$ $$1+1=2$$
\\((c_0,c_1,c_2^2)\\)是一个坐标。""" \\((c_0,c_1,c_2^2)\\)是一个坐标。"""
if custom_prompt: if config.custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n' self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
self.system_prompt += r'\no_think' self.system_prompt += r'\no_think'

View File

@@ -1,23 +1,23 @@
from typing import NotRequired, Unpack from dataclasses import dataclass
from docutranslate.agents import AgentArgs, Agent from docutranslate.agents import AgentConfig, Agent
class TXTTranslateAgentArgs(AgentArgs, total=True): @dataclass
class TXTTranslateAgentConfig(AgentConfig):
to_lang: str to_lang: str
custom_prompt: NotRequired[str] custom_prompt: str | None = None
class TXTTranslateAgent(Agent): class TXTTranslateAgent(Agent):
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]): def __init__(self, config: TXTTranslateAgentConfig):
print(f"custom_prompt:{custom_prompt}") super().__init__(config)
super().__init__(**kwargs)
self.system_prompt = f""" self.system_prompt = f"""
# 角色 # 角色
你是一个专业的机器翻译引擎 你是一个专业的机器翻译引擎
# 工作 # 工作
翻译输入的txt文本 翻译输入的txt文本
目标语言{to_lang} 目标语言{config.to_lang}
# 要求 # 要求
翻译要求专业准确 翻译要求专业准确
不输出任何解释和注释 不输出任何解释和注释
@@ -25,6 +25,6 @@ class TXTTranslateAgent(Agent):
# 输出 # 输出
翻译后的txt译文纯文本 翻译后的txt译文纯文本
""" """
if custom_prompt: if config.custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n' self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
self.system_prompt += r'\no_think' self.system_prompt += r'\no_think'

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Self from typing import Self
from docutranslate.agents import MDTranslateAgent from docutranslate.agents import MDTranslateAgent
from docutranslate.agents.markdown_agent import MDTranslateAgentConfig
from docutranslate.context.md_mask_context import MDMaskUrisContext from docutranslate.context.md_mask_context import MDMaskUrisContext
from docutranslate.ir.markdown_document import MarkdownDocument from docutranslate.ir.markdown_document import MarkdownDocument
from docutranslate.translator.ai_translator.base import AiTranslatorConfig from docutranslate.translator.ai_translator.base import AiTranslatorConfig
@@ -16,12 +17,11 @@ class MDTranslatorConfig(AiTranslatorConfig):
... ...
class MDTranslator(Translator): class MDTranslator(Translator):
def __init__(self, config: MDTranslatorConfig): def __init__(self, config: MDTranslatorConfig):
super().__init__(config=config) super().__init__(config=config)
self.chunk_size = config.chunk_size self.chunk_size = config.chunk_size
self.translate_agent = MDTranslateAgent(custom_prompt=config.custom_prompt, agent_config = MDTranslateAgentConfig(custom_prompt=config.custom_prompt,
to_lang=config.to_lang, to_lang=config.to_lang,
baseurl=config.base_url, baseurl=config.base_url,
key=config.api_key, key=config.api_key,
@@ -31,6 +31,7 @@ class MDTranslator(Translator):
max_concurrent=config.concurrent, max_concurrent=config.concurrent,
timeout=config.timeout, timeout=config.timeout,
logger=self.logger) logger=self.logger)
self.translate_agent = MDTranslateAgent(agent_config)
def translate(self, document: MarkdownDocument) -> Self: def translate(self, document: MarkdownDocument) -> Self:
self.logger.info("正在翻译markdown") self.logger.info("正在翻译markdown")

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Self from typing import Self
from docutranslate.agents.txt_agent import TXTTranslateAgent from docutranslate.agents.txt_agent import TXTTranslateAgent, TXTTranslateAgentConfig
from docutranslate.ir.document import Document from docutranslate.ir.document import Document
from docutranslate.translator.ai_translator.base import AiTranslatorConfig from docutranslate.translator.ai_translator.base import AiTranslatorConfig
from docutranslate.translator.base import Translator from docutranslate.translator.base import Translator
@@ -17,7 +17,7 @@ class TXTTranslator(Translator):
def __init__(self, config: TXTTranslatorConfig): def __init__(self, config: TXTTranslatorConfig):
super().__init__(config=config) super().__init__(config=config)
self.chunk_size = config.chunk_size self.chunk_size = config.chunk_size
self.translate_agent = TXTTranslateAgent(custom_prompt=config.custom_prompt, agent_config = TXTTranslateAgentConfig(custom_prompt=config.custom_prompt,
to_lang=config.to_lang, to_lang=config.to_lang,
baseurl=config.base_url, baseurl=config.base_url,
key=config.api_key, key=config.api_key,
@@ -27,6 +27,7 @@ class TXTTranslator(Translator):
max_concurrent=config.concurrent, max_concurrent=config.concurrent,
timeout=config.timeout, timeout=config.timeout,
logger=self.logger) logger=self.logger)
self.translate_agent = TXTTranslateAgent(agent_config)
def translate(self, document: Document) -> Self: def translate(self, document: Document) -> Self:
self.logger.info("正在翻译txt") self.logger.info("正在翻译txt")