检查agent的result是否满足segments要求

This commit is contained in:
xunbu
2025-08-21 22:36:05 +08:00
parent 4981684e4f
commit c42f02fe08
2 changed files with 77 additions and 37 deletions

View File

@@ -1,10 +1,11 @@
import asyncio import asyncio
import itertools
import logging import logging
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from threading import Lock from threading import Lock
from typing import Literal from typing import Literal, Callable, Any
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@@ -67,6 +68,9 @@ class PromptsCounter:
TIMEOUT = 600 TIMEOUT = 600
# Signature of a success handler: called with (raw model reply, originating prompt, logger).
# The return type is Any because a handler may turn the reply into a non-string value
# (for example parsed JSON), and the send methods return that value unchanged.
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
# Signature of a failure handler: called with (originating prompt, logger) when no usable
# reply was obtained. Its return value is handed back to the caller in place of the prompt.
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
class Agent: class Agent:
_think_factory = { _think_factory = {
@@ -129,7 +133,9 @@ class Agent:
self._add_thinking_mode(data) self._add_thinking_mode(data)
return headers, data return headers, data
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str: async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
if prompt.strip() == "": if prompt.strip() == "":
@@ -145,12 +151,12 @@ class Agent:
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"] result = response.json()["choices"][0]["message"]["content"]
return result return result if result_handler is None else result_handler(result, prompt, self.logger)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}") print(f"prompt:\n{prompt}")
self.total_error_counter.add() self.total_error_counter.add()
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
except httpx.RequestError as e: except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}") self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
except (KeyError, IndexError) as e: except (KeyError, IndexError) as e:
@@ -158,20 +164,23 @@ class Agent:
# 如果没有正常获取结果则重试 # 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT: if retry and retry_count < MAX_RETRY_COUNT:
if self.total_error_counter.add(): if self.total_error_counter.add():
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
self.logger.info(f"正在重试,重试次数{retry_count}") self.logger.info(f"正在重试,重试次数{retry_count}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1) return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler)
else: else:
self.logger.error(f"达到重试次数上限") self.logger.error(f"达到重试次数上限")
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
async def send_prompts_async( async def send_prompts_async(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None # 新增参数默认并发数为5 max_concurrent: int | None = None, # 新增参数默认并发数为5
) -> list[str]: result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
total = len(prompts) total = len(prompts)
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
@@ -186,6 +195,8 @@ class Agent:
result = await self.send_async( result = await self.send_async(
prompt=p_text, prompt=p_text,
system_prompt=system_prompt, system_prompt=system_prompt,
result_handler=result_handler,
error_result_handler=error_result_handler,
) )
nonlocal count nonlocal count
count += 1 count += 1
@@ -199,7 +210,8 @@ class Agent:
results = await asyncio.gather(*tasks, return_exceptions=False) results = await asyncio.gather(*tasks, return_exceptions=False)
return results return results
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str: def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
result_handler=None, error_result_handler=None) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
if prompt.strip() == "": if prompt.strip() == "":
@@ -214,12 +226,12 @@ class Agent:
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"] result = response.json()["choices"][0]["message"]["content"]
return result return result if result_handler is None else result_handler(result, prompt, self.logger)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}") print(f"prompt:\n{prompt}")
self.total_error_counter.add() self.total_error_counter.add()
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
except httpx.RequestError as e: except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}") self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
except (KeyError, IndexError) as e: except (KeyError, IndexError) as e:
@@ -227,16 +239,19 @@ class Agent:
# 如果没有正常获取结果则重试 # 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT: if retry and retry_count < MAX_RETRY_COUNT:
if self.total_error_counter.add(): if self.total_error_counter.add():
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
self.logger.info(f"正在重试,重试次数{retry_count}") self.logger.info(f"正在重试,重试次数{retry_count}")
time.sleep(0.5) time.sleep(0.5)
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1) return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler)
else: else:
self.logger.error(f"达到重试次数上限") self.logger.error(f"达到重试次数上限")
return prompt return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter) -> str: def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler,
result = self.send(prompt, system_prompt) error_result_handler) -> Any:
result = self.send(prompt, system_prompt, result_handler=result_handler,
error_result_handler=error_result_handler)
count.add() count.add()
return result return result
@@ -244,14 +259,23 @@ class Agent:
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
) -> list[str]: result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
) -> list[Any]:
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}") self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
system_prompts = [system_prompt] * len(prompts)
counts = [PromptsCounter(len(prompts), self.logger)] * len(prompts) # 创建单个计数器实例
counter = PromptsCounter(len(prompts), self.logger)
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
system_prompts = itertools.repeat(system_prompt, len(prompts))
counters = itertools.repeat(counter, len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts))
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
output_list = [] output_list = []
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counts) results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
output_list = list(results_iterator) output_list = list(results_iterator)
return output_list return output_list

View File

@@ -2,11 +2,12 @@ import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from json import JSONDecodeError from json import JSONDecodeError
from logging import Logger
from json_repair import json_repair from json_repair import json_repair
from docutranslate.agents import AgentConfig, Agent from docutranslate.agents import AgentConfig, Agent
from docutranslate.utils.json_utils import segments2json_chunks from docutranslate.utils.json_utils import segments2json_chunks
@dataclass @dataclass
@@ -43,23 +44,38 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
if config.custom_prompt: if config.custom_prompt:
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n' self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
    """Turn the raw model reply into a JSON object, falling back to the prompt.

    Args:
        result: Raw text returned by the model.
        origin_prompt: The prompt that produced ``result``.
        logger: Logger used to report a reply that cannot be used.

    Returns:
        The parsed reply when it is a JSON object (``dict``); otherwise the
        value of ``self._error_result_handler(origin_prompt, logger)``.
    """
    try:
        parsed = json_repair.loads(result)
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed here.
        parsed = None
    # The segment senders iterate the returned value with ``.items()``, so a
    # reply that parses to something other than a dict is treated as a failure.
    if isinstance(parsed, dict):
        return parsed
    logger.error("结果不能正确解析")
    return self._error_result_handler(origin_prompt, logger)
def _error_result_handler(self, origin_prompt: str, logger: Logger):
    """Produce the fallback value used when no usable reply was obtained.

    Args:
        origin_prompt: The prompt that was sent.
        logger: Logger used to report a prompt that cannot be parsed.

    Returns:
        ``origin_prompt`` parsed with ``json_repair.loads``; if parsing raises,
        the unmodified ``origin_prompt`` string.
    """
    try:
        return json_repair.loads(origin_prompt)
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed here.
        logger.error("prompt不是json格式")
        return origin_prompt
def send_segments(self, segments: list[str], chunk_size: int): def send_segments(self, segments: list[str], chunk_size: int):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts) translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk_str in translated_chunks: for chunk in translated_chunks:
try: try:
translated_part = json_repair.loads(chunk_str) for key, val in chunk.items():
for key, val in translated_part.items():
if key in indexed_translated: if key in indexed_translated:
indexed_translated[key] = val indexed_translated[key] = val
except JSONDecodeError as e: except JSONDecodeError as e:
self.logger.info(f"json解析错误解析文本:{chunk_str},错误:{e.__repr__()}") self.logger.info(f"json解析错误解析文本:{chunk},错误:{e.__repr__()}")
except ValueError as e: except ValueError as e:
self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}") self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments错误:{e.__repr__()}") self.logger.info(f"send_segments发生错误:{e.__repr__()}")
# 初始化结果列表 # 初始化结果列表
result = [] result = []
@@ -79,28 +95,28 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
# todo:增加协程粒度 # todo:增加协程粒度
async def send_segments_async(self, segments: list[str], chunk_size: int): async def send_segments_async(self, segments: list[str], chunk_size: int):
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks,segments, chunk_size) indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts) translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk_str in translated_chunks: for chunk in translated_chunks:
try: try:
translated_part = json_repair.loads(chunk_str) for key, val in chunk.items():
for key, val in translated_part.items():
if key in indexed_translated: if key in indexed_translated:
indexed_translated[key] = val indexed_translated[key] = val
except JSONDecodeError as e: except JSONDecodeError as e:
self.logger.info(f"json解析错误解析文本:{chunk_str},错误:{e.__repr__()}") self.logger.info(f"json解析错误解析文本:{chunk},错误:{e.__repr__()}")
except ValueError as e: except ValueError as e:
self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}") self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments错误:{e.__repr__()}") self.logger.info(f"send_segments发生错误:{e.__repr__()}")
# 初始化结果列表 # 初始化结果列表
result = [] result = []
last_end = 0 last_end = 0
ls=list(indexed_translated.values()) ls = list(indexed_translated.values())
for start, end in merged_indices_list: for start, end in merged_indices_list:
# 添加未处理的部分 # 添加未处理的部分
result.extend(ls[last_end:start]) result.extend(ls[last_end:start])