初步建立agent术语表配置

This commit is contained in:
xunbu
2025-08-27 21:05:30 +08:00
parent b1d68c2fc0
commit 78a4525108
6 changed files with 81 additions and 14 deletions

View File

@@ -34,11 +34,11 @@ class AgentConfig:
class TotalErrorCounter:
def __init__(self, logger: logging.Logger,max_errors_count=10):
def __init__(self, logger: logging.Logger, max_errors_count=10):
self.lock = Lock()
self.count = 0
self.logger = logger
self.max_errors_count=max_errors_count
self.max_errors_count = max_errors_count
def add(self):
self.lock.acquire()
@@ -67,7 +67,7 @@ class PromptsCounter:
self.lock.release()
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
ResultHandlerType = Callable[[str, str, logging.Logger], str]
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
@@ -128,10 +128,13 @@ class Agent:
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
@@ -172,6 +175,7 @@ class Agent:
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None, # 新增参数默认并发数为5
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
@@ -179,7 +183,7 @@ class Agent:
total = len(prompts)
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count=len(prompts) // MAX_REQUESTS_PER_ERROR #允许多少个异常
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常
count = 0
semaphore = asyncio.Semaphore(max_concurrent)
tasks = []
@@ -194,6 +198,7 @@ class Agent:
client=client,
prompt=p_text,
system_prompt=system_prompt,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
)
@@ -210,9 +215,11 @@ class Agent:
return results
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
result_handler=None, error_result_handler=None) -> Any:
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
@@ -259,6 +266,7 @@ class Agent:
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler:PreSendHandlerType=None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
@@ -271,6 +279,7 @@ class Agent:
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
system_prompts = itertools.repeat(system_prompt, len(prompts))
counters = itertools.repeat(counter, len(prompts))
pre_send_handlers=itertools.repeat(pre_send_handler,len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts))
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
output_list = []
@@ -279,6 +288,7 @@ class Agent:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
pre_send_handlers,
result_handlers,
error_result_handlers)
output_list = list(results_iterator)

View File

@@ -1,14 +1,18 @@
from dataclasses import dataclass
from .agent import Agent, AgentConfig
from ..glossary.glossary import Glossary
@dataclass
class MDTranslateAgentConfig(AgentConfig):
to_lang:str
custom_prompt:str|None=None
to_lang: str
custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None
class MDTranslateAgent(Agent):
def __init__(self,config:MDTranslateAgentConfig):
def __init__(self, config: MDTranslateAgentConfig):
super().__init__(config)
self.system_prompt = f"""
# Role
@@ -48,5 +52,17 @@ Output:
这个方程是 $E=mc^2$。这很有名。
$$1+1=2$$
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** \n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
    """Extend the system prompt with glossary text before a request is sent.

    When ``self.glossary_dict`` is falsy both arguments are returned
    untouched. Otherwise a ``Glossary`` is built from it and the result of
    ``Glossary.append_system_prompt(prompt)`` is appended to
    ``system_prompt``. The user prompt is never modified.

    Returns:
        The ``(system_prompt, prompt)`` pair to use for the request.
    """
    if not self.glossary_dict:
        return system_prompt, prompt
    term_table = Glossary(glossary_dict=self.glossary_dict)
    extended_system_prompt = system_prompt + term_table.append_system_prompt(prompt)
    return extended_system_prompt, prompt
def send_chunks(self, prompts: list[str]):
    """Send all prompts via the base-class ``send_prompts``.

    ``self._pre_send_handler`` is passed as the pre-send hook; the value
    returned by the base implementation is returned unchanged.
    """
    hook = self._pre_send_handler
    return super().send_prompts(prompts=prompts, pre_send_handler=hook)
async def send_chunks_async(self, prompts: list[str]):
    """Await the base-class ``send_prompts_async`` for all prompts.

    ``self._pre_send_handler`` is passed as the pre-send hook; the awaited
    result of the base implementation is returned unchanged.
    """
    hook = self._pre_send_handler
    responses = await super().send_prompts_async(prompts=prompts, pre_send_handler=hook)
    return responses

View File

@@ -7,6 +7,7 @@ from logging import Logger
from json_repair import json_repair
from docutranslate.agents import AgentConfig, Agent
from docutranslate.glossary.glossary import Glossary
from docutranslate.utils.json_utils import segments2json_chunks
@@ -14,6 +15,7 @@ from docutranslate.utils.json_utils import segments2json_chunks
class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str
custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None
class SegmentsTranslateAgent(Agent):
@@ -43,13 +45,21 @@ Output
{r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'}
Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```.
"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** for segments for translation \n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
    """Extend the system prompt with glossary text before a request is sent.

    When ``self.glossary_dict`` is falsy both arguments are returned
    untouched. Otherwise a ``Glossary`` is built from it and the result of
    ``Glossary.append_system_prompt(prompt)`` is appended to
    ``system_prompt``. The user prompt is never modified.

    Returns:
        The ``(system_prompt, prompt)`` pair to use for the request.
    """
    if not self.glossary_dict:
        return system_prompt, prompt
    term_table = Glossary(glossary_dict=self.glossary_dict)
    extended_system_prompt = system_prompt + term_table.append_system_prompt(prompt)
    return extended_system_prompt, prompt
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
try:
result = json_repair.loads(result)
if not isinstance(result,dict):
if not isinstance(result, dict):
raise ValueError("agent返回结果不是dict的json形式")
except:
logger.error("结果不能正确解析")
@@ -66,7 +76,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
def send_segments(self, segments: list[str], chunk_size: int):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy()
for chunk in translated_chunks:
@@ -102,7 +113,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy()
for chunk in translated_chunks:

View File

@@ -1,12 +1,14 @@
from dataclasses import dataclass
from docutranslate.agents import AgentConfig, Agent
from docutranslate.glossary.glossary import Glossary
@dataclass
class TXTTranslateAgentConfig(AgentConfig):
    """Configuration for the plain-text translation agent, extending ``AgentConfig``."""

    # Target language; interpolated into the agent's system prompt.
    to_lang: str
    # Optional extra instructions; when truthy they are appended to the system
    # prompt under an "Important rules or background" heading.
    custom_prompt: str | None = None
    # Optional source-term -> target-term mapping; the agent wraps it in a
    # ``Glossary`` in its pre-send handler when it is truthy.
    glossary_dict: dict[str, str] | None = None
class TXTTranslateAgent(Agent):
@@ -30,5 +32,17 @@ Target language: {config.to_lang}
# Output
The translated txt text as plain text.
"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background**\n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
    """Extend the system prompt with glossary text before a request is sent.

    When ``self.glossary_dict`` is falsy both arguments are returned
    untouched. Otherwise a ``Glossary`` is built from it and the result of
    ``Glossary.append_system_prompt(prompt)`` is appended to
    ``system_prompt``. The user prompt is never modified.

    Returns:
        The ``(system_prompt, prompt)`` pair to use for the request.
    """
    if not self.glossary_dict:
        return system_prompt, prompt
    term_table = Glossary(glossary_dict=self.glossary_dict)
    extended_system_prompt = system_prompt + term_table.append_system_prompt(prompt)
    return extended_system_prompt, prompt
def send_chunks(self, prompts: list[str]):
    """Send all prompts via the base-class ``send_prompts``.

    ``self._pre_send_handler`` is passed as the pre-send hook; the value
    returned by the base implementation is returned unchanged.
    """
    hook = self._pre_send_handler
    return super().send_prompts(prompts=prompts, pre_send_handler=hook)
async def send_chunks_async(self, prompts: list[str]):
    """Await the base-class ``send_prompts_async`` for all prompts.

    ``self._pre_send_handler`` is passed as the pre-send hook; the awaited
    result of the base implementation is returned unchanged.
    """
    hook = self._pre_send_handler
    responses = await super().send_prompts_async(prompts=prompts, pre_send_handler=hook)
    return responses

View File

View File

@@ -0,0 +1,15 @@
class Glossary:
def __init__(self,glossary_dict:dict[str:str]|None=None):
self.glossary_dict=glossary_dict
def update(self,update_dict:dict[str:str]):
for src,dst in update_dict.items():
if src not in self.glossary_dict:
self.glossary_dict[src]=dst
def append_system_prompt(self,text:str):
prompt="\n以下为参考术语表:\n"
for src,dst in self.glossary_dict:
if src in text:
prompt+=f"{src}=>{dst}\n"
prompt+="术语表结束\n"