agent配置化
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .agent import Agent, AgentArgs
|
||||
from .markdown_agent import MDRefineAgent, MDTranslateAgent
|
||||
from .agent import Agent, AgentConfig
|
||||
from .markdown_agent import MDTranslateAgent
|
||||
|
||||
@@ -2,8 +2,8 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -13,15 +13,16 @@ MAX_RETRY_COUNT = 2
|
||||
MAX_TOTAL_ERROR_COUNT = 10
|
||||
|
||||
|
||||
class AgentArgs(TypedDict, total=False):
|
||||
@dataclass(kw_only=True)
|
||||
class AgentConfig:
|
||||
logger: logging.Logger
|
||||
baseurl: str
|
||||
key: str
|
||||
model_id: str
|
||||
system_prompt: str | None
|
||||
temperature: float
|
||||
max_concurrent: int
|
||||
timeout: int
|
||||
logger: logging.Logger
|
||||
temperature: float = 0.7
|
||||
max_concurrent: int = 30
|
||||
timeout: int = 2000
|
||||
|
||||
|
||||
class TotalErrorCounter:
|
||||
@@ -61,21 +62,20 @@ TIMEOUT = 600
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, baseurl: str, key: str | None, model_id: str, system_prompt: str | None = None, temperature=0.7,
|
||||
max_concurrent=15, timeout: int = TIMEOUT, logger: logging.Logger | None = None):
|
||||
self.baseurl = baseurl.strip()
|
||||
def __init__(self, config: AgentConfig):
|
||||
self.baseurl = config.baseurl.strip()
|
||||
if self.baseurl.endswith("/"):
|
||||
self.baseurl = self.baseurl[:-1]
|
||||
self.key = key.strip() or "xx"
|
||||
self.model_id = model_id.strip()
|
||||
self.system_prompt = system_prompt or ""
|
||||
self.temperature = temperature
|
||||
self.key = config.key.strip() or "xx"
|
||||
self.model_id = config.model_id.strip()
|
||||
self.system_prompt = config.system_prompt or ""
|
||||
self.temperature = config.temperature
|
||||
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
|
||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
||||
self.max_concurrent = max_concurrent
|
||||
self.timeout = timeout
|
||||
self.max_concurrent = config.max_concurrent
|
||||
self.timeout = config.timeout
|
||||
|
||||
self.logger = logger if logger else global_logger
|
||||
self.logger = config.logger or global_logger
|
||||
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
||||
|
||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||
|
||||
@@ -1,60 +1,21 @@
|
||||
from typing import Unpack, NotRequired
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .agent import Agent, AgentArgs
|
||||
from .agent import Agent, AgentConfig
|
||||
|
||||
class MDTranslateAgentArgs(AgentArgs, total=True):
|
||||
@dataclass
|
||||
class MDTranslateAgentConfig(AgentConfig):
|
||||
to_lang:str
|
||||
custom_prompt:NotRequired[str]
|
||||
|
||||
class MDRefineAgent(Agent):
|
||||
def __init__(self, custom_prompt=None, **kwargs: Unpack[AgentArgs]):
|
||||
super().__init__(**kwargs)
|
||||
self.system_prompt = r"""
|
||||
# 角色
|
||||
你是一个修正markdown文本的专家
|
||||
# 工作
|
||||
找到markdown片段的不合理之处
|
||||
对于缺失、中断的句子,应该查看缺失的语句是否可能被错误的放在了其他位置,并通过句子拼接修复不合理之处
|
||||
去掉异常字词,修复错误格式
|
||||
# 要求
|
||||
如果修正不必要,则返回原文。
|
||||
不要解释,不要注释。
|
||||
不要修改标题的级别(如一级标题不要修改为二级标题)
|
||||
形如<ph-ads231>的占位符不要改变【非常重要】
|
||||
code、latex和HTML保持结构
|
||||
所有公式(包括短公式)都应该是latex公式
|
||||
公式无论长短必须表示为能被解析的合法latex公式,公式需被$或\\(\\)或$$正确包裹,如不正确则进行修正
|
||||
# 输出
|
||||
修正后的markdown纯文本(不是markdown代码块)
|
||||
# 示例
|
||||
## 修正文本流
|
||||
输入:
|
||||
什么名字
|
||||
你叫
|
||||
输出:
|
||||
你叫什么名字
|
||||
## 去掉异常字词与修正公式(行内公式使用$包裹)
|
||||
输入:
|
||||
一道\题@#目<ph-12asd2>:c_0+1=2,\(c 0\)等于几
|
||||
{c_0,c_1,c^2}是一个集合
|
||||
输出:
|
||||
一道题目<ph-12asd2>:$c_0+1=2$,$c_0$等于几
|
||||
{$c_0$,$c_1$,$c^2$}是一个集合"""
|
||||
if custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
||||
self.system_prompt += r'\no_think'
|
||||
|
||||
custom_prompt:str|None=None
|
||||
|
||||
class MDTranslateAgent(Agent):
|
||||
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
|
||||
print(f"custom_prompt:{custom_prompt}")
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self,config:MDTranslateAgentConfig):
|
||||
super().__init__(config)
|
||||
self.system_prompt = f"""
|
||||
# 角色
|
||||
你是一个专业的机器翻译引擎
|
||||
# 工作
|
||||
翻译输入的markdown文本
|
||||
目标语言{to_lang}
|
||||
目标语言{config.to_lang}
|
||||
# 要求
|
||||
翻译要求专业准确
|
||||
不输出任何解释和注释
|
||||
@@ -81,6 +42,6 @@ The equation is E=mc 2. This is famous.
|
||||
这个方程是 $E=mc^2$。这很有名。
|
||||
$$1+1=2$$
|
||||
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
||||
if custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||
self.system_prompt += r'\no_think'
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
from typing import NotRequired, Unpack
|
||||
from dataclasses import dataclass
|
||||
|
||||
from docutranslate.agents import AgentArgs, Agent
|
||||
from docutranslate.agents import AgentConfig, Agent
|
||||
|
||||
|
||||
class TXTTranslateAgentArgs(AgentArgs, total=True):
|
||||
@dataclass
|
||||
class TXTTranslateAgentConfig(AgentConfig):
|
||||
to_lang: str
|
||||
custom_prompt: NotRequired[str]
|
||||
custom_prompt: str | None = None
|
||||
|
||||
|
||||
class TXTTranslateAgent(Agent):
|
||||
def __init__(self, custom_prompt=None, to_lang="中文", **kwargs: Unpack[AgentArgs]):
|
||||
print(f"custom_prompt:{custom_prompt}")
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, config: TXTTranslateAgentConfig):
|
||||
super().__init__(config)
|
||||
self.system_prompt = f"""
|
||||
# 角色
|
||||
你是一个专业的机器翻译引擎
|
||||
# 工作
|
||||
翻译输入的txt文本
|
||||
目标语言{to_lang}
|
||||
目标语言{config.to_lang}
|
||||
# 要求
|
||||
翻译要求专业准确
|
||||
不输出任何解释和注释
|
||||
@@ -25,6 +25,6 @@ class TXTTranslateAgent(Agent):
|
||||
# 输出
|
||||
翻译后的txt译文纯文本
|
||||
"""
|
||||
if custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + custom_prompt + '\n'
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||
self.system_prompt += r'\no_think'
|
||||
|
||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from docutranslate.agents import MDTranslateAgent
|
||||
from docutranslate.agents.markdown_agent import MDTranslateAgentConfig
|
||||
from docutranslate.context.md_mask_context import MDMaskUrisContext
|
||||
from docutranslate.ir.markdown_document import MarkdownDocument
|
||||
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
||||
@@ -16,21 +17,21 @@ class MDTranslatorConfig(AiTranslatorConfig):
|
||||
...
|
||||
|
||||
|
||||
|
||||
class MDTranslator(Translator):
|
||||
def __init__(self, config: MDTranslatorConfig):
|
||||
super().__init__(config=config)
|
||||
self.chunk_size = config.chunk_size
|
||||
self.translate_agent = MDTranslateAgent(custom_prompt=config.custom_prompt,
|
||||
to_lang=config.to_lang,
|
||||
baseurl=config.base_url,
|
||||
key=config.api_key,
|
||||
model_id=config.model_id,
|
||||
system_prompt=None,
|
||||
temperature=config.temperature,
|
||||
max_concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger)
|
||||
agent_config = MDTranslateAgentConfig(custom_prompt=config.custom_prompt,
|
||||
to_lang=config.to_lang,
|
||||
baseurl=config.base_url,
|
||||
key=config.api_key,
|
||||
model_id=config.model_id,
|
||||
system_prompt=None,
|
||||
temperature=config.temperature,
|
||||
max_concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger)
|
||||
self.translate_agent = MDTranslateAgent(agent_config)
|
||||
|
||||
def translate(self, document: MarkdownDocument) -> Self:
|
||||
self.logger.info("正在翻译markdown")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from docutranslate.agents.txt_agent import TXTTranslateAgent
|
||||
from docutranslate.agents.txt_agent import TXTTranslateAgent, TXTTranslateAgentConfig
|
||||
from docutranslate.ir.document import Document
|
||||
from docutranslate.translator.ai_translator.base import AiTranslatorConfig
|
||||
from docutranslate.translator.base import Translator
|
||||
@@ -17,16 +17,17 @@ class TXTTranslator(Translator):
|
||||
def __init__(self, config: TXTTranslatorConfig):
|
||||
super().__init__(config=config)
|
||||
self.chunk_size = config.chunk_size
|
||||
self.translate_agent = TXTTranslateAgent(custom_prompt=config.custom_prompt,
|
||||
to_lang=config.to_lang,
|
||||
baseurl=config.base_url,
|
||||
key=config.api_key,
|
||||
model_id=config.model_id,
|
||||
system_prompt=None,
|
||||
temperature=config.temperature,
|
||||
max_concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger)
|
||||
agent_config = TXTTranslateAgentConfig(custom_prompt=config.custom_prompt,
|
||||
to_lang=config.to_lang,
|
||||
baseurl=config.base_url,
|
||||
key=config.api_key,
|
||||
model_id=config.model_id,
|
||||
system_prompt=None,
|
||||
temperature=config.temperature,
|
||||
max_concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger)
|
||||
self.translate_agent = TXTTranslateAgent(agent_config)
|
||||
|
||||
def translate(self, document: Document) -> Self:
|
||||
self.logger.info("正在翻译txt")
|
||||
|
||||
Reference in New Issue
Block a user