agent配置化
This commit is contained in:
@@ -1,2 +1,2 @@
|
|||||||
from .agent import Agent, AgentArgs
|
from .agent import Agent, AgentConfig
|
||||||
from .markdown_agent import MDRefineAgent, MDTranslateAgent
|
from .markdown_agent import MDTranslateAgent
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TypedDict
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -13,15 +13,16 @@ MAX_RETRY_COUNT = 2
|
|||||||
MAX_TOTAL_ERROR_COUNT = 10
|
MAX_TOTAL_ERROR_COUNT = 10
|
||||||
|
|
||||||
|
|
||||||
class AgentArgs(TypedDict, total=False):
|
@dataclass(kw_only=True)
|
||||||
|
class AgentConfig:
|
||||||
|
logger: logging.Logger
|
||||||
baseurl: str
|
baseurl: str
|
||||||
key: str
|
key: str
|
||||||
model_id: str
|
model_id: str
|
||||||
system_prompt: str | None
|
system_prompt: str | None
|
||||||
temperature: float
|
temperature: float = 0.7
|
||||||
max_concurrent: int
|
max_concurrent: int = 30
|
||||||
timeout: int
|
timeout: int = 2000
|
||||||
logger: logging.Logger
|
|
||||||
|
|
||||||
|
|
||||||
class TotalErrorCounter:
|
class TotalErrorCounter:
|
||||||
@@ -61,21 +62,20 @@ TIMEOUT = 600
|
|||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, baseurl: str, key: str | None, model_id: str, system_prompt: str | None = None, temperature=0.7,
|
def __init__(self, config: AgentConfig):
|
||||||
max_concurrent=15, timeout: int = TIMEOUT, logger: logging.Logger | None = None):
|
self.baseurl = config.baseurl.strip()
|
||||||
self.baseurl = baseurl.strip()
|
|
||||||
if self.baseurl.endswith("/"):
|
if self.baseurl.endswith("/"):
|
||||||
self.baseurl = self.baseurl[:-1]
|
self.baseurl = self.baseurl[:-1]
|
||||||
self.key = key.strip() or "xx"
|
self.key = config.key.strip() or "xx"
|
||||||
self.model_id = model_id.strip()
|
self.model_id = config.model_id.strip()
|
||||||
self.system_prompt = system_prompt or ""
|
self.system_prompt = config.system_prompt or ""
|
||||||
self.temperature = temperature
|
self.temperature = config.temperature
|
||||||
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
|
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
|
||||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
||||||
self.max_concurrent = max_concurrent
|
self.max_concurrent = config.max_concurrent
|
||||||
self.timeout = timeout
|
self.timeout = config.timeout
|
||||||
|
|
||||||
self.logger = logger if logger else global_logger
|
self.logger = config.logger or global_logger
|
||||||
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
||||||
|
|
||||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||||
|
|||||||
@@ -1,60 +1,21 @@
|
|||||||
from typing import Unpack, NotRequired
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from .agent import Agent, AgentArgs
|
from .agent import Agent, AgentConfig
|
||||||
|
|
||||||
class MDTranslateAgentArgs(AgentArgs, total=True):
|
@dataclass
|
||||||
|
class MDTranslateAgentConfig(AgentConfig):
|
||||||
to_lang:str
|
to_lang:str
|
||||||
custom_prompt:NotRequired[str]
|
custom_prompt:str|None=None
|
||||||
|
|
||||||
class MDRefineAgent(Agent):
|
|
||||||
def __init__(self, custom_prompt=None, **kwargs: Unpack[AgentArgs]):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.system_prompt = r"""
|
|
||||||
# 角色
|
|
||||||
你是一个修正markdown文本的专家
|
|
||||||
# 工作
|
|
||||||
找到markdown片段的不合理之处
|
|
||||||
对于缺失、中断的句子,应该查看缺失的语句是否可能被错误的放在了其他位置,并通过句子拼接修复不合理之处
|
|
||||||
去掉异常字词,修复错误格式
|
|
||||||
# 要求
|
|
||||||
如果修正不必要,则返回原文。
|
|
||||||
不要解释,不要注释。
|
|
||||||
不要修改标题的级别(如一级标题不要修改为二级标题)
|
|
||||||
形如<ph-ads231>的占位符不要改变【非常重要】
|
|
||||||
code、latex和HTML保持结构
|
|
||||||
所有公式(包括短公式)都应该是latex公式
|
|
||||||
公式无论长短必须表示为能被解析的合法latex公式,公式需被$或\\(\\)或$$正确包裹,如不正确则进行修正
|
|
||||||
# 输出
|
|
||||||
修正后的markdown纯文本(不是markdown代码块)
|
|
||||||
# 示例
|
|
||||||
## 修正文本流
|
|
||||||
输入:
|
|
||||||
什么名字
|
|
||||||
你叫
|
|
||||||
输出:
|
|
||||||
你叫什么名字
|
|
||||||
## 去掉异常字词与修正公式(行内公式使用$包裹)
|
|
||||||
输入:
|
|
||||||
一道\题@#目<ph-12asd2>:c_0+1=2,\(c 0\)等于几
|
|
||||||
{c_0,c_1,c^2}是一个集合
|
|
||||||
输出:
|
|
||||||
一道题目<ph-12asd2>:$c_0+1=2$,$c_0$等于几
|
|
||||||
{$c_0$,$c_1$,$c^2$}是一个集合"""
|
|
||||||
if custom_prompt:
|
|
||||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
|
||||||
self.system_prompt += r'\no_think'
|
|
||||||
|
|
||||||
|
|
||||||
class MDTranslateAgent(Agent):
|
class MDTranslateAgent(Agent):
|
||||||
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
|
def __init__(self,config:MDTranslateAgentConfig):
|
||||||
print(f"custom_prompt:{custom_prompt}")
|
super().__init__(config)
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# 角色
|
# 角色
|
||||||
你是一个专业的机器翻译引擎
|
你是一个专业的机器翻译引擎
|
||||||
# 工作
|
# 工作
|
||||||
翻译输入的markdown文本
|
翻译输入的markdown文本
|
||||||
目标语言{to_lang}
|
目标语言{config.to_lang}
|
||||||
# 要求
|
# 要求
|
||||||
翻译要求专业准确
|
翻译要求专业准确
|
||||||
不输出任何解释和注释
|
不输出任何解释和注释
|
||||||
@@ -81,6 +42,6 @@ The equation is E=mc 2. This is famous.
|
|||||||
这个方程是 $E=mc^2$。这很有名。
|
这个方程是 $E=mc^2$。这很有名。
|
||||||
$$1+1=2$$
|
$$1+1=2$$
|
||||||
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
||||||
if custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||||
self.system_prompt += r'\no_think'
|
self.system_prompt += r'\no_think'
|
||||||
|
|||||||
@@ -1,23 +1,23 @@
|
|||||||
from typing import NotRequired, Unpack
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from docutranslate.agents import AgentArgs, Agent
|
from docutranslate.agents import AgentConfig, Agent
|
||||||
|
|
||||||
|
|
||||||
class TXTTranslateAgentArgs(AgentArgs, total=True):
|
@dataclass
|
||||||
|
class TXTTranslateAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: NotRequired[str]
|
custom_prompt: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class TXTTranslateAgent(Agent):
|
class TXTTranslateAgent(Agent):
|
||||||
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
|
def __init__(self, config: TXTTranslateAgentConfig):
|
||||||
print(f"custom_prompt:{custom_prompt}")
|
super().__init__(config)
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# 角色
|
# 角色
|
||||||
你是一个专业的机器翻译引擎
|
你是一个专业的机器翻译引擎
|
||||||
# 工作
|
# 工作
|
||||||
翻译输入的txt文本
|
翻译输入的txt文本
|
||||||
目标语言{to_lang}
|
目标语言{config.to_lang}
|
||||||
# 要求
|
# 要求
|
||||||
翻译要求专业准确
|
翻译要求专业准确
|
||||||
不输出任何解释和注释
|
不输出任何解释和注释
|
||||||
@@ -25,6 +25,6 @@ class TXTTranslateAgent(Agent):
|
|||||||
# 输出
|
# 输出
|
||||||
翻译后的txt译文纯文本
|
翻译后的txt译文纯文本
|
||||||
"""
|
"""
|
||||||
if custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||||
self.system_prompt += r'\no_think'
|
self.system_prompt += r'\no_think'
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from docutranslate.agents import MDTranslateAgent
|
from docutranslate.agents import MDTranslateAgent
|
||||||
|
from docutranslate.agents.markdown_agent import MDTranslateAgentConfig
|
||||||
from docutranslate.context.md_mask_context import MDMaskUrisContext
|
from docutranslate.context.md_mask_context import MDMaskUrisContext
|
||||||
from docutranslate.ir.markdown_document import MarkdownDocument
|
from docutranslate.ir.markdown_document import MarkdownDocument
|
||||||
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
||||||
@@ -16,21 +17,21 @@ class MDTranslatorConfig(AiTranslatorConfig):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MDTranslator(Translator):
|
class MDTranslator(Translator):
|
||||||
def __init__(self, config: MDTranslatorConfig):
|
def __init__(self, config: MDTranslatorConfig):
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
self.chunk_size = config.chunk_size
|
self.chunk_size = config.chunk_size
|
||||||
self.translate_agent = MDTranslateAgent(custom_prompt=config.custom_prompt,
|
agent_config = MDTranslateAgentConfig(custom_prompt=config.custom_prompt,
|
||||||
to_lang=config.to_lang,
|
to_lang=config.to_lang,
|
||||||
baseurl=config.base_url,
|
baseurl=config.base_url,
|
||||||
key=config.api_key,
|
key=config.api_key,
|
||||||
model_id=config.model_id,
|
model_id=config.model_id,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
max_concurrent=config.concurrent,
|
max_concurrent=config.concurrent,
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
logger=self.logger)
|
logger=self.logger)
|
||||||
|
self.translate_agent = MDTranslateAgent(agent_config)
|
||||||
|
|
||||||
def translate(self, document: MarkdownDocument) -> Self:
|
def translate(self, document: MarkdownDocument) -> Self:
|
||||||
self.logger.info("正在翻译markdown")
|
self.logger.info("正在翻译markdown")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from docutranslate.agents.txt_agent import TXTTranslateAgent
|
from docutranslate.agents.txt_agent import TXTTranslateAgent, TXTTranslateAgentConfig
|
||||||
from docutranslate.ir.document import Document
|
from docutranslate.ir.document import Document
|
||||||
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
||||||
from docutranslate.translator.base import Translator
|
from docutranslate.translator.base import Translator
|
||||||
@@ -17,16 +17,17 @@ class TXTTranslator(Translator):
|
|||||||
def __init__(self, config: TXTTranslatorConfig):
|
def __init__(self, config: TXTTranslatorConfig):
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
self.chunk_size = config.chunk_size
|
self.chunk_size = config.chunk_size
|
||||||
self.translate_agent = TXTTranslateAgent(custom_prompt=config.custom_prompt,
|
agent_config = TXTTranslateAgentConfig(custom_prompt=config.custom_prompt,
|
||||||
to_lang=config.to_lang,
|
to_lang=config.to_lang,
|
||||||
baseurl=config.base_url,
|
baseurl=config.base_url,
|
||||||
key=config.api_key,
|
key=config.api_key,
|
||||||
model_id=config.model_id,
|
model_id=config.model_id,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
max_concurrent=config.concurrent,
|
max_concurrent=config.concurrent,
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
logger=self.logger)
|
logger=self.logger)
|
||||||
|
self.translate_agent = TXTTranslateAgent(agent_config)
|
||||||
|
|
||||||
def translate(self, document: Document) -> Self:
|
def translate(self, document: Document) -> Self:
|
||||||
self.logger.info("正在翻译txt")
|
self.logger.info("正在翻译txt")
|
||||||
|
|||||||
Reference in New Issue
Block a user