初步建立agent术语表配置
This commit is contained in:
@@ -34,11 +34,11 @@ class AgentConfig:
|
|||||||
|
|
||||||
|
|
||||||
class TotalErrorCounter:
|
class TotalErrorCounter:
|
||||||
def __init__(self, logger: logging.Logger,max_errors_count=10):
|
def __init__(self, logger: logging.Logger, max_errors_count=10):
|
||||||
self.lock = Lock()
|
self.lock = Lock()
|
||||||
self.count = 0
|
self.count = 0
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.max_errors_count=max_errors_count
|
self.max_errors_count = max_errors_count
|
||||||
|
|
||||||
def add(self):
|
def add(self):
|
||||||
self.lock.acquire()
|
self.lock.acquire()
|
||||||
@@ -67,7 +67,7 @@ class PromptsCounter:
|
|||||||
self.lock.release()
|
self.lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||||||
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
||||||
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
||||||
|
|
||||||
@@ -128,10 +128,13 @@ class Agent:
|
|||||||
|
|
||||||
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
|
if pre_send_handler:
|
||||||
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
# if prompt.strip() == "":
|
# if prompt.strip() == "":
|
||||||
# return prompt
|
# return prompt
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||||
@@ -172,6 +175,7 @@ class Agent:
|
|||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
||||||
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None
|
error_result_handler: ErrorResultHandlerType = None
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
@@ -179,7 +183,7 @@ class Agent:
|
|||||||
total = len(prompts)
|
total = len(prompts)
|
||||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||||
self.total_error_counter.max_errors_count=len(prompts) // MAX_REQUESTS_PER_ERROR #允许多少个异常
|
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常
|
||||||
count = 0
|
count = 0
|
||||||
semaphore = asyncio.Semaphore(max_concurrent)
|
semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
tasks = []
|
tasks = []
|
||||||
@@ -194,6 +198,7 @@ class Agent:
|
|||||||
client=client,
|
client=client,
|
||||||
prompt=p_text,
|
prompt=p_text,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
)
|
)
|
||||||
@@ -210,9 +215,11 @@ class Agent:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||||
result_handler=None, error_result_handler=None) -> Any:
|
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
|
if pre_send_handler:
|
||||||
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
# if prompt.strip() == "":
|
# if prompt.strip() == "":
|
||||||
# return prompt
|
# return prompt
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||||
@@ -259,6 +266,7 @@ class Agent:
|
|||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
|
pre_send_handler:PreSendHandlerType=None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None
|
error_result_handler: ErrorResultHandlerType = None
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
@@ -271,6 +279,7 @@ class Agent:
|
|||||||
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
||||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||||
counters = itertools.repeat(counter, len(prompts))
|
counters = itertools.repeat(counter, len(prompts))
|
||||||
|
pre_send_handlers=itertools.repeat(pre_send_handler,len(prompts))
|
||||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||||
output_list = []
|
output_list = []
|
||||||
@@ -279,6 +288,7 @@ class Agent:
|
|||||||
clients = itertools.repeat(client, len(prompts))
|
clients = itertools.repeat(client, len(prompts))
|
||||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||||
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
||||||
|
pre_send_handlers,
|
||||||
result_handlers,
|
result_handlers,
|
||||||
error_result_handlers)
|
error_result_handlers)
|
||||||
output_list = list(results_iterator)
|
output_list = list(results_iterator)
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from .agent import Agent, AgentConfig
|
from .agent import Agent, AgentConfig
|
||||||
|
from ..glossary.glossary import Glossary
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MDTranslateAgentConfig(AgentConfig):
|
class MDTranslateAgentConfig(AgentConfig):
|
||||||
to_lang:str
|
to_lang: str
|
||||||
custom_prompt:str|None=None
|
custom_prompt: str | None = None
|
||||||
|
glossary_dict: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class MDTranslateAgent(Agent):
|
class MDTranslateAgent(Agent):
|
||||||
def __init__(self,config:MDTranslateAgentConfig):
|
def __init__(self, config: MDTranslateAgentConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# Role
|
# Role
|
||||||
@@ -48,5 +52,17 @@ Output:
|
|||||||
这个方程是 $E=mc^2$。这很有名。
|
这个方程是 $E=mc^2$。这很有名。
|
||||||
$$1+1=2$$
|
$$1+1=2$$
|
||||||
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
||||||
|
self.custom_prompt = config.custom_prompt
|
||||||
if config.custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# **Important rules or background** \n" + config.custom_prompt + '\n'
|
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||||
|
self.glossary_dict = config.glossary_dict
|
||||||
|
|
||||||
|
def _pre_send_handler(self, system_prompt, prompt):
|
||||||
|
if self.glossary_dict:
|
||||||
|
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||||
|
system_prompt += glossary.append_system_prompt(prompt)
|
||||||
|
return system_prompt, prompt
|
||||||
|
def send_chunks(self, prompts: list[str]):
|
||||||
|
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||||
|
async def send_chunks_async(self, prompts: list[str]):
|
||||||
|
return await super().send_prompts_async(prompts=prompts,pre_send_handler=self._pre_send_handler)
|
||||||
@@ -7,6 +7,7 @@ from logging import Logger
|
|||||||
from json_repair import json_repair
|
from json_repair import json_repair
|
||||||
|
|
||||||
from docutranslate.agents import AgentConfig, Agent
|
from docutranslate.agents import AgentConfig, Agent
|
||||||
|
from docutranslate.glossary.glossary import Glossary
|
||||||
from docutranslate.utils.json_utils import segments2json_chunks
|
from docutranslate.utils.json_utils import segments2json_chunks
|
||||||
|
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ from docutranslate.utils.json_utils import segments2json_chunks
|
|||||||
class SegmentsTranslateAgentConfig(AgentConfig):
|
class SegmentsTranslateAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str | None = None
|
custom_prompt: str | None = None
|
||||||
|
glossary_dict: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SegmentsTranslateAgent(Agent):
|
class SegmentsTranslateAgent(Agent):
|
||||||
@@ -43,13 +45,21 @@ Output
|
|||||||
{r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'}
|
{r'{"0":"你好","1":"苹果","2":true,"3":"错误"}'}
|
||||||
Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```.
|
Warning: Never wrap the entire JSON object in quotes to make it a single string. Never wrap the JSON text in ```.
|
||||||
"""
|
"""
|
||||||
|
self.custom_prompt = config.custom_prompt
|
||||||
if config.custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# **Important rules or background** for segments for translation \n" + config.custom_prompt + '\n'
|
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||||
|
self.glossary_dict = config.glossary_dict
|
||||||
|
|
||||||
|
def _pre_send_handler(self, system_prompt, prompt):
|
||||||
|
if self.glossary_dict:
|
||||||
|
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||||
|
system_prompt += glossary.append_system_prompt(prompt)
|
||||||
|
return system_prompt, prompt
|
||||||
|
|
||||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
||||||
try:
|
try:
|
||||||
result = json_repair.loads(result)
|
result = json_repair.loads(result)
|
||||||
if not isinstance(result,dict):
|
if not isinstance(result, dict):
|
||||||
raise ValueError("agent返回结果不是dict的json形式")
|
raise ValueError("agent返回结果不是dict的json形式")
|
||||||
except:
|
except:
|
||||||
logger.error("结果不能正确解析")
|
logger.error("结果不能正确解析")
|
||||||
@@ -66,7 +76,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
|||||||
def send_segments(self, segments: list[str], chunk_size: int):
|
def send_segments(self, segments: list[str], chunk_size: int):
|
||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||||
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
|
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||||
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
indexed_translated = indexed_originals.copy()
|
indexed_translated = indexed_originals.copy()
|
||||||
for chunk in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
@@ -102,7 +113,8 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
|||||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||||
chunk_size)
|
chunk_size)
|
||||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||||
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
|
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||||
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
indexed_translated = indexed_originals.copy()
|
indexed_translated = indexed_originals.copy()
|
||||||
for chunk in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from docutranslate.agents import AgentConfig, Agent
|
from docutranslate.agents import AgentConfig, Agent
|
||||||
|
from docutranslate.glossary.glossary import Glossary
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TXTTranslateAgentConfig(AgentConfig):
|
class TXTTranslateAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str | None = None
|
custom_prompt: str | None = None
|
||||||
|
glossary_dict: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class TXTTranslateAgent(Agent):
|
class TXTTranslateAgent(Agent):
|
||||||
@@ -30,5 +32,17 @@ Target language: {config.to_lang}
|
|||||||
# Output
|
# Output
|
||||||
The translated txt text as plain text.
|
The translated txt text as plain text.
|
||||||
"""
|
"""
|
||||||
|
self.custom_prompt = config.custom_prompt
|
||||||
if config.custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# **Important rules or background**\n" + config.custom_prompt + '\n'
|
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||||
|
self.glossary_dict = config.glossary_dict
|
||||||
|
|
||||||
|
def _pre_send_handler(self, system_prompt, prompt):
|
||||||
|
if self.glossary_dict:
|
||||||
|
glossary = Glossary(glossary_dict=self.glossary_dict)
|
||||||
|
system_prompt += glossary.append_system_prompt(prompt)
|
||||||
|
return system_prompt, prompt
|
||||||
|
def send_chunks(self, prompts: list[str]):
|
||||||
|
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||||
|
async def send_chunks_async(self, prompts: list[str]):
|
||||||
|
return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler)
|
||||||
|
|||||||
0
docutranslate/glossary/__init__.py
Normal file
0
docutranslate/glossary/__init__.py
Normal file
15
docutranslate/glossary/glossary.py
Normal file
15
docutranslate/glossary/glossary.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
class Glossary:
|
||||||
|
def __init__(self,glossary_dict:dict[str:str]|None=None):
|
||||||
|
self.glossary_dict=glossary_dict
|
||||||
|
|
||||||
|
def update(self,update_dict:dict[str:str]):
|
||||||
|
for src,dst in update_dict.items():
|
||||||
|
if src not in self.glossary_dict:
|
||||||
|
self.glossary_dict[src]=dst
|
||||||
|
|
||||||
|
def append_system_prompt(self,text:str):
|
||||||
|
prompt="\n以下为参考术语表:\n"
|
||||||
|
for src,dst in self.glossary_dict:
|
||||||
|
if src in text:
|
||||||
|
prompt+=f"{src}=>{dst}\n"
|
||||||
|
prompt+="术语表结束\n"
|
||||||
Reference in New Issue
Block a user