初步建立agent术语表配置
This commit is contained in:
@@ -34,11 +34,11 @@ class AgentConfig:
|
||||
|
||||
|
||||
class TotalErrorCounter:
|
||||
def __init__(self, logger: logging.Logger,max_errors_count=10):
|
||||
def __init__(self, logger: logging.Logger, max_errors_count=10):
|
||||
self.lock = Lock()
|
||||
self.count = 0
|
||||
self.logger = logger
|
||||
self.max_errors_count=max_errors_count
|
||||
self.max_errors_count = max_errors_count
|
||||
|
||||
def add(self):
|
||||
self.lock.acquire()
|
||||
@@ -67,7 +67,7 @@ class PromptsCounter:
|
||||
self.lock.release()
|
||||
|
||||
|
||||
|
||||
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||||
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
||||
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
||||
|
||||
@@ -128,10 +128,13 @@ class Agent:
|
||||
|
||||
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
@@ -172,6 +175,7 @@ class Agent:
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
) -> list[Any]:
|
||||
@@ -179,7 +183,7 @@ class Agent:
|
||||
total = len(prompts)
|
||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||
self.total_error_counter.max_errors_count=len(prompts) // MAX_REQUESTS_PER_ERROR #允许多少个异常
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常
|
||||
count = 0
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
@@ -194,6 +198,7 @@ class Agent:
|
||||
client=client,
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
)
|
||||
@@ -210,9 +215,11 @@ class Agent:
|
||||
return results
|
||||
|
||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
result_handler=None, error_result_handler=None) -> Any:
|
||||
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
@@ -259,6 +266,7 @@ class Agent:
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
pre_send_handler:PreSendHandlerType=None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
) -> list[Any]:
|
||||
@@ -271,6 +279,7 @@ class Agent:
|
||||
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||
counters = itertools.repeat(counter, len(prompts))
|
||||
pre_send_handlers=itertools.repeat(pre_send_handler,len(prompts))
|
||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||
output_list = []
|
||||
@@ -279,6 +288,7 @@ class Agent:
|
||||
clients = itertools.repeat(client, len(prompts))
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
||||
pre_send_handlers,
|
||||
result_handlers,
|
||||
error_result_handlers)
|
||||
output_list = list(results_iterator)
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .agent import Agent, AgentConfig
|
||||
from ..glossary.glossary import Glossary
|
||||
|
||||
|
||||
@dataclass
|
||||
class MDTranslateAgentConfig(AgentConfig):
|
||||
to_lang:str
|
||||
custom_prompt:str|None=None
|
||||
to_lang: str
|
||||
custom_prompt: str | None = None
|
||||
glossary_dict: dict[str, str] | None = None
|
||||
|
||||
|
||||
class MDTranslateAgent(Agent):
|
||||
def __init__(self,config:MDTranslateAgentConfig):
|
||||
def __init__(self, config: MDTranslateAgentConfig):
|
||||
super().__init__(config)
|
||||
self.system_prompt = f"""
|
||||
# Role
|
||||
@@ -48,5 +52,17 @@ Output:
|
||||
这个方程是 $E=mc^2$。这很有名。
|
||||
$$1+1=2$$
|
||||
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
||||
self.custom_prompt = config.custom_prompt
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + config.custom_prompt + '\n'
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||
self.glossary_dict = config.glossary_dict
|
||||
|
||||
def _pre_send_handler(self, system_prompt, prompt):
|
||||
if self.glossary_dict:
|
||||
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||
system_prompt += glossary.append_system_prompt(prompt)
|
||||
return system_prompt, prompt
|
||||
def send_chunks(self, prompts: list[str]):
|
||||
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||
async def send_chunks_async(self, prompts: list[str]):
|
||||
return await super().send_prompts_async(prompts=prompts,pre_send_handler=self._pre_send_handler)
|
||||
@@ -7,6 +7,7 @@ from logging import Logger
|
||||
from json_repair import json_repair
|
||||
|
||||
from docutranslate.agents import AgentConfig, Agent
|
||||
from docutranslate.glossary.glossary import Glossary
|
||||
from docutranslate.utils.json_utils import segments2json_chunks
|
||||
|
||||
|
||||
@@ -14,6 +15,7 @@ from docutranslate.utils.json_utils import segments2json_chunks
|
||||
class SegmentsTranslateAgentConfig(AgentConfig):
|
||||
to_lang: str
|
||||
custom_prompt: str | None = None
|
||||
glossary_dict: dict[str, str] | None = None
|
||||
|
||||
|
||||
class SegmentsTranslateAgent(Agent):
|
||||
@@ -43,13 +45,21 @@ Output
|
||||
{r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'}
|
||||
Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```.
|
||||
"""
|
||||
self.custom_prompt = config.custom_prompt
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# **Important rules or background** for segments for translation \n" + config.custom_prompt + '\n'
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||
self.glossary_dict = config.glossary_dict
|
||||
|
||||
def _pre_send_handler(self, system_prompt, prompt):
|
||||
if self.glossary_dict:
|
||||
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||
system_prompt += glossary.append_system_prompt(prompt)
|
||||
return system_prompt, prompt
|
||||
|
||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
||||
try:
|
||||
result = json_repair.loads(result)
|
||||
if not isinstance(result,dict):
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("agent返回结果不是dict的json形式")
|
||||
except:
|
||||
logger.error("结果不能正确解析")
|
||||
@@ -66,7 +76,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
def send_segments(self, segments: list[str], chunk_size: int):
|
||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
|
||||
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk in translated_chunks:
|
||||
@@ -102,7 +113,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||
chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk in translated_chunks:
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from docutranslate.agents import AgentConfig, Agent
|
||||
from docutranslate.glossary.glossary import Glossary
|
||||
|
||||
|
||||
@dataclass
|
||||
class TXTTranslateAgentConfig(AgentConfig):
|
||||
to_lang: str
|
||||
custom_prompt: str | None = None
|
||||
glossary_dict: dict[str, str] | None = None
|
||||
|
||||
|
||||
class TXTTranslateAgent(Agent):
|
||||
@@ -30,5 +32,17 @@ Target language: {config.to_lang}
|
||||
# Output
|
||||
The translated txt text as plain text.
|
||||
"""
|
||||
self.custom_prompt = config.custom_prompt
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# **Important rules or background**\n" + config.custom_prompt + '\n'
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||
self.glossary_dict = config.glossary_dict
|
||||
|
||||
def _pre_send_handler(self, system_prompt, prompt):
|
||||
if self.glossary_dict:
|
||||
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||
system_prompt += glossary.append_system_prompt(prompt)
|
||||
return system_prompt, prompt
|
||||
def send_chunks(self, prompts: list[str]):
|
||||
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||
async def send_chunks_async(self, prompts: list[str]):
|
||||
return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||
|
||||
0
docutranslate/glossary/__init__.py
Normal file
0
docutranslate/glossary/__init__.py
Normal file
15
docutranslate/glossary/glossary.py
Normal file
15
docutranslate/glossary/glossary.py
Normal file
@@ -0,0 +1,15 @@
|
||||
class Glossary:
|
||||
def __init__(self,glossary_dict:dict[str:str]|None=None):
|
||||
self.glossary_dict=glossary_dict
|
||||
|
||||
def update(self,update_dict:dict[str:str]):
|
||||
for src,dst in update_dict.items():
|
||||
if src not in self.glossary_dict:
|
||||
self.glossary_dict[src]=dst
|
||||
|
||||
def append_system_prompt(self,text:str):
|
||||
prompt="\n以下为参考术语表:\n"
|
||||
for src,dst in self.glossary_dict:
|
||||
if src in text:
|
||||
prompt+=f"{src}=>{dst}\n"
|
||||
prompt+="术语表结束\n"
|
||||
Reference in New Issue
Block a user