初步建立agent术语表配置

This commit is contained in:
xunbu
2025-08-27 21:05:30 +08:00
parent b1d68c2fc0
commit 78a4525108
6 changed files with 81 additions and 14 deletions

View File

@@ -67,7 +67,7 @@ class PromptsCounter:
self.lock.release()
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
ResultHandlerType = Callable[[str, str, logging.Logger], str]
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
@@ -128,10 +128,13 @@ class Agent:
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
@@ -172,6 +175,7 @@ class Agent:
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None, # 新增参数默认并发数为5
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
@@ -194,6 +198,7 @@ class Agent:
client=client,
prompt=p_text,
system_prompt=system_prompt,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
)
@@ -210,9 +215,11 @@ class Agent:
return results
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
result_handler=None, error_result_handler=None) -> Any:
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
@@ -259,6 +266,7 @@ class Agent:
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler:PreSendHandlerType=None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
@@ -271,6 +279,7 @@ class Agent:
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
system_prompts = itertools.repeat(system_prompt, len(prompts))
counters = itertools.repeat(counter, len(prompts))
pre_send_handlers=itertools.repeat(pre_send_handler,len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts))
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
output_list = []
@@ -279,6 +288,7 @@ class Agent:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
pre_send_handlers,
result_handlers,
error_result_handlers)
output_list = list(results_iterator)

View File

@@ -1,11 +1,15 @@
from dataclasses import dataclass
from .agent import Agent, AgentConfig
from ..glossary.glossary import Glossary
@dataclass
class MDTranslateAgentConfig(AgentConfig):
to_lang: str
custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None
class MDTranslateAgent(Agent):
def __init__(self, config: MDTranslateAgentConfig):
@@ -48,5 +52,17 @@ Output:
这个方程是 $E=mc^2$。这很有名。
$$1+1=2$$
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** \n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
if self.glossary_dict:
glossary = Glossary(glossary_dict=self.glossary_dict)
system_prompt += glossary.append_system_prompt(prompt)
return system_prompt, prompt
def send_chunks(self, prompts: list[str]):
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
async def send_chunks_async(self, prompts: list[str]):
return await super().send_prompts_async(prompts=prompts,pre_send_handler=self._pre_send_handler)

View File

@@ -7,6 +7,7 @@ from logging import Logger
from json_repair import json_repair
from docutranslate.agents import AgentConfig, Agent
from docutranslate.glossary.glossary import Glossary
from docutranslate.utils.json_utils import segments2json_chunks
@@ -14,6 +15,7 @@ from docutranslate.utils.json_utils import segments2json_chunks
class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str
custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None
class SegmentsTranslateAgent(Agent):
@@ -43,8 +45,16 @@ Output
{r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'}
Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```.
"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** for segments for translation \n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
if self.glossary_dict:
glossary = Glossary(glossary_dict=self.glossary_dict)
system_prompt += glossary.append_system_prompt(prompt)
return system_prompt, prompt
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
try:
@@ -66,7 +76,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
def send_segments(self, segments: list[str], chunk_size: int):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy()
for chunk in translated_chunks:
@@ -102,7 +113,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy()
for chunk in translated_chunks:

View File

@@ -1,12 +1,14 @@
from dataclasses import dataclass
from docutranslate.agents import AgentConfig, Agent
from docutranslate.glossary.glossary import Glossary
@dataclass
class TXTTranslateAgentConfig(AgentConfig):
to_lang: str
custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None
class TXTTranslateAgent(Agent):
@@ -30,5 +32,17 @@ Target language: {config.to_lang}
# Output
The translated txt text as plain text.
"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background**\n" + config.custom_prompt + '\n'
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt):
if self.glossary_dict:
glossary = Glossary(glossary_dict=self.glossary_dict)
system_prompt += glossary.append_system_prompt(prompt)
return system_prompt, prompt
def send_chunks(self, prompts: list[str]):
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
async def send_chunks_async(self, prompts: list[str]):
return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler)

View File

View File

@@ -0,0 +1,15 @@
class Glossary:
def __init__(self,glossary_dict:dict[str:str]|None=None):
self.glossary_dict=glossary_dict
def update(self,update_dict:dict[str:str]):
for src,dst in update_dict.items():
if src not in self.glossary_dict:
self.glossary_dict[src]=dst
def append_system_prompt(self,text:str):
prompt="\n以下为参考术语表:\n"
for src,dst in self.glossary_dict:
if src in text:
prompt+=f"{src}=>{dst}\n"
prompt+="术语表结束\n"