优化agent异常重试机制与提示
This commit is contained in:
@@ -23,6 +23,14 @@ MAX_REQUESTS_PER_ERROR = 20
|
||||
ThinkingMode = Literal["enable", "disable", "default"]
|
||||
|
||||
|
||||
class PartialTranslationError(ValueError):
    """Signal an incomplete result that still carries partially successful data.

    Raising this exception lets the caller trigger a retry while keeping
    the partial data available on the exception object.
    """

    def __init__(self, message, partial_result: dict):
        """Store the message and the partial result.

        Args:
            message: Description of why the result is incomplete.
            partial_result: The partially successful data to keep.
        """
        super().__init__(message)
        self.partial_result = partial_result
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class AgentConfig:
|
||||
logger: logging.Logger
|
||||
@@ -43,12 +51,11 @@ class TotalErrorCounter:
|
||||
self.max_errors_count = max_errors_count
|
||||
|
||||
def add(self):
    """Record one more error and report whether the limit is exceeded.

    The increment, the log message and the limit check run while
    ``self.lock`` is held. Using ``with`` releases the lock even when
    logging raises, which a manual ``acquire()``/``release()`` pair
    would not do.

    Returns:
        The value of ``self.reach_limit()`` after the increment.
    """
    with self.lock:
        self.count += 1
        if self.count > self.max_errors_count:
            self.logger.info("错误响应过多")
        return self.reach_limit()
|
||||
|
||||
def reach_limit(self) -> bool:
    """Return True when ``self.count`` is greater than ``self.max_errors_count``."""
    return self.count > self.max_errors_count
|
||||
@@ -63,10 +70,9 @@ class PromptsCounter:
|
||||
self.logger = logger
|
||||
|
||||
def add(self):
    """Count one more finished prompt and log the progress.

    The increment and the progress log run while ``self.lock`` is held;
    the ``with`` statement releases the lock even if logging raises.
    """
    with self.lock:
        self.count += 1
        self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
|
||||
|
||||
|
||||
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||||
@@ -129,7 +135,6 @@ class Agent:
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": temperature,
|
||||
@@ -143,14 +148,16 @@ class Agent:
|
||||
retry_count=0,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
best_partial_result: dict | None = None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
should_retry = False
|
||||
current_partial_result = None
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
@@ -161,35 +168,70 @@ class Agent:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
if retry_count > 0:
|
||||
self.logger.info(f"重试成功 (第 {retry_count + 1}/{MAX_RETRY_COUNT + 1} 次尝试)。")
|
||||
|
||||
# print(f"result:=============================================================\n{result}\n================\n")
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
|
||||
# 专门捕获部分翻译错误
|
||||
except PartialTranslationError as e:
|
||||
self.logger.error(f"收到部分翻译结果,将尝试重试: {e}")
|
||||
current_partial_result = e.partial_result # 保存这次的部分结果
|
||||
should_retry = True
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
print(f"prompt:\n{prompt}")
|
||||
self.total_error_counter.add()
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
should_retry = True
|
||||
except httpx.RequestError as e:
|
||||
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
|
||||
except (KeyError, IndexError) as e:
|
||||
raise Exception(f"AI响应格式错误 (async): {repr(e)}")
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"{e.__repr__()}")
|
||||
# 如果没有正常获取结果则重试
|
||||
if retry and retry_count < MAX_RETRY_COUNT:
|
||||
if self.total_error_counter.add():
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
self.logger.error(f"AI请求连接错误 (async): {repr(e)}")
|
||||
should_retry = True
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
self.logger.error(f"AI响应格式或值错误 (async), 将尝试重试: {repr(e)}")
|
||||
should_retry = True
|
||||
|
||||
# 如果当前捕获到了部分结果,就更新“最佳”结果
|
||||
if current_partial_result:
|
||||
best_partial_result = current_partial_result
|
||||
|
||||
if should_retry and retry and retry_count < MAX_RETRY_COUNT:
|
||||
if retry_count == 0:
|
||||
if self.total_error_counter.add():
|
||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||
# 如果有部分结果,优先返回部分结果
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
elif self.total_error_counter.reach_limit():
|
||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
|
||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||
await asyncio.sleep(0.5)
|
||||
# 将“最佳”结果传递给下一次递归调用
|
||||
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
best_partial_result=best_partial_result)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
if should_retry:
|
||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||
|
||||
# 在最终失败时,检查是否有可用的部分结果
|
||||
if best_partial_result:
|
||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||
return best_partial_result
|
||||
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
|
||||
async def send_prompts_async(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
||||
max_concurrent: int | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
@@ -197,19 +239,18 @@ class Agent:
|
||||
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
||||
total = len(prompts)
|
||||
self.logger.info(
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}")
|
||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
count = 0
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
|
||||
# 辅助协程,用于包装 self.send_async 并使用信号量
|
||||
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client:
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||
async with semaphore:
|
||||
result = await self.send_async(
|
||||
client=client,
|
||||
prompt=p_text,
|
||||
@@ -231,14 +272,17 @@ class Agent:
|
||||
return results
|
||||
|
||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any:
|
||||
pre_send_handler=None, result_handler=None, error_result_handler=None,
|
||||
best_partial_result: dict | None = None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
should_retry = False
|
||||
current_partial_result = None
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
@@ -248,28 +292,63 @@ class Agent:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
if retry_count > 0:
|
||||
self.logger.info(f"重试成功 (第 {retry_count + 1}/{MAX_RETRY_COUNT + 1} 次尝试)。")
|
||||
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
|
||||
# --- MODIFICATION START ---
|
||||
except PartialTranslationError as e:
|
||||
self.logger.error(f"收到部分翻译结果,将尝试重试: {e}")
|
||||
current_partial_result = e.partial_result
|
||||
should_retry = True
|
||||
# --- MODIFICATION END ---
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}")
|
||||
self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}")
|
||||
print(f"prompt:\n{prompt}")
|
||||
self.total_error_counter.add()
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
should_retry = True
|
||||
except httpx.RequestError as e:
|
||||
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||||
except (KeyError, IndexError) as e:
|
||||
raise Exception(f"AI响应格式错误 (sync): {repr(e)}")
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"{e.__repr__()}")
|
||||
# 如果没有正常获取结果则重试
|
||||
if retry and retry_count < MAX_RETRY_COUNT:
|
||||
if self.total_error_counter.add():
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
self.logger.error(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||||
should_retry = True
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
self.logger.error(f"AI响应格式或值错误 (sync), 将尝试重试: {repr(e)}")
|
||||
should_retry = True
|
||||
|
||||
# --- MODIFICATION START ---
|
||||
if current_partial_result:
|
||||
best_partial_result = current_partial_result
|
||||
# --- MODIFICATION END ---
|
||||
|
||||
if should_retry and retry and retry_count < MAX_RETRY_COUNT:
|
||||
if retry_count == 0:
|
||||
if self.total_error_counter.add():
|
||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
elif self.total_error_counter.reach_limit():
|
||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
|
||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||
time.sleep(0.5)
|
||||
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
best_partial_result=best_partial_result)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
if should_retry:
|
||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||
|
||||
# --- MODIFICATION START ---
|
||||
if best_partial_result:
|
||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||
return best_partial_result
|
||||
# --- MODIFICATION END ---
|
||||
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
|
||||
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
|
||||
@@ -293,17 +372,15 @@ class Agent:
|
||||
self.logger.info(
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
|
||||
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常
|
||||
# 创建单个计数器实例
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
counter = PromptsCounter(len(prompts), self.logger)
|
||||
|
||||
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||
counters = itertools.repeat(counter, len(prompts))
|
||||
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
|
||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||
output_list = []
|
||||
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client:
|
||||
clients = itertools.repeat(client, len(prompts))
|
||||
|
||||
@@ -50,24 +50,27 @@ The output format should be plain JSON text in a list format
|
||||
|
||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
    """Parse a glossary response into a list.

    Args:
        result: Raw text returned by the model.
        origin_prompt: The prompt that produced ``result``.
        logger: Logger used to report an unexpected empty result.

    Returns:
        ``[]`` when ``result`` is empty and the prompt is blank, otherwise
        the list parsed from ``result`` by ``json_repair.loads``.

    Raises:
        ValueError: If ``result`` is empty although the prompt is not
            blank, if the parsed value is not a list, or if parsing raises
            ``RuntimeError`` / ``JSONDecodeError``.
    """
    if result == "":
        if origin_prompt.strip() != "":
            logger.error("result为空值但原文不为空")
            raise ValueError("result为空值但原文不为空")
        return []
    try:
        repaired_result = json_repair.loads(result)
        if not isinstance(repaired_result, list):
            raise ValueError(f"GlossaryAgent返回结果不是list的json形式, result: {result}")
        return repaired_result
    except (RuntimeError, JSONDecodeError) as e:
        # Wrap parse failures in ValueError so the send method catches them and retries.
        raise ValueError(f"结果不能正确解析: {e!r}") from e
|
||||
|
||||
def _error_result_handler(self, origin_prompt: str, logger: Logger):
    """Build the fallback value for a request that ultimately failed.

    Args:
        origin_prompt: The prompt whose request failed.
        logger: Logger used to report an unparsable prompt.

    Returns:
        ``[]`` for an empty prompt; otherwise whatever
        ``json_repair.loads`` produces for the prompt, or ``[]`` when that
        call raises ``RuntimeError`` / ``JSONDecodeError``.
    """
    if origin_prompt == "":
        return []
    try:
        return json_repair.loads(origin_prompt)
    except (RuntimeError, JSONDecodeError):
        logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
        # The original prompt is unusable as well, so fall back to an empty list.
        return []
|
||||
|
||||
def send_segments(self, segments: list[str], chunk_size: int):
|
||||
self.logger.info(f"开始提取术语表,to_lang:{self.to_lang}")
|
||||
@@ -78,14 +81,17 @@ The output format should be plain JSON text in a list format
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
for chunk in translated_chunks:
|
||||
chunk: list[dict[str, str]]
|
||||
try:
|
||||
glossary_dict = {d["src"]: d["dst"] for d in chunk}
|
||||
if not isinstance(chunk, list):
|
||||
self.logger.error(f"接收到的chunk不是有效的列表,已跳过: {chunk}")
|
||||
continue
|
||||
glossary_dict = {d["src"]: d["dst"] for d in chunk if isinstance(d, dict) and "src" in d and "dst" in d}
|
||||
result = glossary_dict | result
|
||||
except JSONDecodeError as e:
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||
except (TypeError, KeyError) as e:
|
||||
self.logger.error(f"处理glossary chunk时发生键或类型错误,已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
|
||||
except Exception as e:
|
||||
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||
self.logger.error(f"处理glossary chunk时发生未知错误: {e.__repr__()}")
|
||||
|
||||
self.logger.info("术语表提取完成")
|
||||
return result
|
||||
|
||||
@@ -99,14 +105,16 @@ The output format should be plain JSON text in a list format
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
for chunk in translated_chunks:
|
||||
chunk: list[dict[str, str]]
|
||||
try:
|
||||
glossary_dict = {d["src"]: d["dst"] for d in chunk}
|
||||
if not isinstance(chunk, list):
|
||||
self.logger.error(f"接收到的chunk不是有效的列表,已跳过: {chunk}")
|
||||
continue
|
||||
glossary_dict = {d["src"]: d["dst"] for d in chunk if isinstance(d, dict) and "src" in d and "dst" in d}
|
||||
result = result | glossary_dict
|
||||
except JSONDecodeError as e:
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||
except (TypeError, KeyError) as e:
|
||||
self.logger.error(f"处理glossary chunk时发生键或类型错误,已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
|
||||
except Exception as e:
|
||||
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||
# print(f"术语表:\n{result}")
|
||||
self.logger.error(f"处理glossary chunk时发生未知错误: {e.__repr__()}")
|
||||
|
||||
self.logger.info("术语表提取完成")
|
||||
return result
|
||||
@@ -57,7 +57,7 @@ $$1+1=2$$
|
||||
\\((c_0,c_1,c_2^2)\\)是一个坐标。"""
|
||||
self.custom_prompt = config.custom_prompt
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\nEND\n'
|
||||
self.glossary_dict = config.glossary_dict
|
||||
|
||||
def _pre_send_handler(self, system_prompt, prompt):
|
||||
|
||||
@@ -10,6 +10,7 @@ from logging import Logger
|
||||
from json_repair import json_repair
|
||||
|
||||
from docutranslate.agents import AgentConfig, Agent
|
||||
from docutranslate.agents.agent import PartialTranslationError
|
||||
from docutranslate.glossary.glossary import Glossary
|
||||
from docutranslate.utils.json_utils import segments2json_chunks
|
||||
|
||||
@@ -50,7 +51,7 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
"""
|
||||
self.custom_prompt = config.custom_prompt
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n'
|
||||
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\nEND\n'
|
||||
self.glossary_dict = config.glossary_dict
|
||||
|
||||
def _pre_send_handler(self, system_prompt, prompt):
|
||||
@@ -60,94 +61,152 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
return system_prompt, prompt
|
||||
|
||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
    """Validate a successful API response against the original chunk.

    - If the keys match exactly, the translated chunk is returned with all
      values converted to ``str``.
    - If the keys differ, a partial result is assembled (translated values
      for shared keys, original values for missing keys) and raised inside
      ``PartialTranslationError`` to trigger a retry.
    - Other failures (unparsable JSON, a non-dict result, a result equal
      to the original chunk) raise a plain ``ValueError`` to trigger a retry.

    Args:
        result: Raw text returned by the model.
        origin_prompt: JSON text of the chunk that was sent.
        logger: Logger used for warnings and errors.

    Returns:
        ``{}`` when ``result`` is empty and the prompt is blank, otherwise
        the translated chunk as a ``dict`` of strings.

    Raises:
        PartialTranslationError: When the result keys differ from the
            original keys.
        ValueError: For the other failure cases listed above.
    """
    if result == "":
        if origin_prompt.strip() != "":
            logger.error("result为空值但原文不为空")
            raise ValueError("result为空值但原文不为空")
        return {}
    try:
        original_chunk = json.loads(origin_prompt)
        repaired_result = json_repair.loads(result)

        if not isinstance(repaired_result, dict):
            raise ValueError(f"Agent返回结果不是dict的json形式, result: {result}")

        if repaired_result == original_chunk:
            raise ValueError("翻译结果与原文完全相同,判定为翻译失败,将进行重试。")

        original_keys = set(original_chunk.keys())
        result_keys = set(repaired_result.keys())

        if original_keys != result_keys:
            missing_keys = original_keys - result_keys
            extra_keys = result_keys - original_keys

            logger.warning("翻译结果的键与原文不匹配!将尝试重试。")
            if missing_keys:
                logger.warning(f"缺失的键: {missing_keys}")
            if extra_keys:
                logger.warning(f"多余的键: {extra_keys}")

            # Build the most complete partial result: translated text where
            # available, original text for the keys the model dropped.
            final_chunk = {key: str(repaired_result[key]) for key in original_keys & result_keys}
            for key in missing_keys:
                final_chunk[key] = str(original_chunk[key])

            # Hand the partial result to the caller together with the error.
            raise PartialTranslationError("键不匹配,触发重试", partial_result=final_chunk)

        # Keys match exactly: return the translation with string values.
        return {key: str(value) for key, value in repaired_result.items()}

    except (RuntimeError, JSONDecodeError) as e:
        # Hard failures such as JSON parsing errors stay plain ValueErrors.
        raise ValueError(f"结果处理失败: {e!r}") from e
|
||||
|
||||
def _error_result_handler(self, origin_prompt: str, logger: Logger):
    """Build the fallback for a request that still failed after all retries.

    As a last resort the original chunk is returned with every value
    converted to ``str``.

    Args:
        origin_prompt: JSON text of the chunk that was sent.
        logger: Logger used to report an unusable prompt.

    Returns:
        ``{}`` for an empty prompt; the original chunk with string values
        when the prompt is a JSON object; otherwise
        ``{"error": origin_prompt}``.
    """
    if origin_prompt == "":
        return {}
    try:
        original_chunk = json.loads(origin_prompt)
        # Valid JSON that is not an object (e.g. a list) has no ``items``;
        # treat it like an unparsable prompt instead of raising AttributeError.
        if isinstance(original_chunk, dict):
            return {key: f"{value}" for key, value in original_chunk.items()}
    except (RuntimeError, JSONDecodeError):
        # Fall through to the shared error report below.
        pass
    logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
    # The prompt itself is unusable, so return an explicit error object.
    return {"error": f"{origin_prompt}"}
|
||||
|
||||
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
    """Translate ``segments`` synchronously and rebuild the segment list.

    The segments are packed into JSON chunks by ``segments2json_chunks``,
    each chunk is sent through ``send_prompts``, and the returned values
    are written back by key. Index ranges listed in
    ``merged_indices_list`` are joined back into a single item.

    Args:
        segments: Text segments to translate.
        chunk_size: Size limit forwarded to ``segments2json_chunks``.

    Returns:
        The rebuilt list; keys that received no translation keep the value
        from ``indexed_originals``.
    """
    indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
    prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]

    translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
                                             result_handler=self._result_handler,
                                             error_result_handler=self._error_result_handler)

    indexed_translated = indexed_originals.copy()
    for chunk in translated_chunks:
        try:
            if not isinstance(chunk, dict):
                self.logger.warning(f"接收到的chunk不是有效的字典,已跳过: {chunk}")
                continue
            for key, val in chunk.items():
                if key in indexed_translated:
                    # Values are stored as received, without str() conversion.
                    indexed_translated[key] = val
                else:
                    self.logger.warning(f"在结果chunk中发现未知键 '{key}',已忽略。")
        except (AttributeError, TypeError) as e:
            self.logger.error(f"处理chunk时发生类型或属性错误,已跳过。Chunk: {chunk}, 错误: {e!r}")
        except Exception as e:
            self.logger.error(f"处理chunk时发生未知错误: {e!r}")

    # Rebuild the final list, joining the ranges that were split earlier.
    result = []
    last_end = 0
    ls = list(indexed_translated.values())
    for start, end in merged_indices_list:
        result.extend(ls[last_end:start])
        result.append("".join(map(str, ls[start:end])))
        last_end = end

    result.extend(ls[last_end:])
    return result
|
||||
|
||||
# todo:增加协程粒度
|
||||
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
    """Translate ``segments`` asynchronously and rebuild the segment list.

    ``segments2json_chunks`` runs in a worker thread via
    ``asyncio.to_thread``; the chunks are sent through
    ``send_prompts_async`` and the returned values are written back by
    key. Index ranges listed in ``merged_indices_list`` are joined back
    into a single item.

    Args:
        segments: Text segments to translate.
        chunk_size: Size limit forwarded to ``segments2json_chunks``.

    Returns:
        The rebuilt list; keys that received no translation keep the value
        from ``indexed_originals``.
    """
    indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
                                                                             chunk_size)
    prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]

    translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
                                                         result_handler=self._result_handler,
                                                         error_result_handler=self._error_result_handler)

    indexed_translated = indexed_originals.copy()
    for chunk in translated_chunks:
        try:
            if not isinstance(chunk, dict):
                self.logger.error(f"接收到的chunk不是有效的字典,已跳过: {chunk}")
                continue
            for key, val in chunk.items():
                if key in indexed_translated:
                    # No str() needed here: _result_handler already converts the values.
                    indexed_translated[key] = val
                else:
                    self.logger.warning(f"在结果chunk中发现未知键 '{key}',已忽略。")
        except (AttributeError, TypeError) as e:
            self.logger.error(f"处理chunk时发生类型或属性错误,已跳过。Chunk: {chunk}, 错误: {e!r}")
        except Exception as e:
            self.logger.error(f"处理chunk时发生未知错误: {e!r}")

    # Rebuild the final list, joining the ranges that were split earlier.
    result = []
    last_end = 0
    ls = list(indexed_translated.values())
    for start, end in merged_indices_list:
        result.extend(ls[last_end:start])
        result.append("".join(map(str, ls[start:end])))
        last_end = end

    result.extend(ls[last_end:])
    return result
|
||||
|
||||
|
||||
@@ -12,86 +12,74 @@ def segments2json_chunks(segments: list[str], chunk_size_max: int) -> tuple[dict
|
||||
list[dict[str, str]], list[tuple[int, int]]]:
|
||||
"""
|
||||
将文本段列表(segments)转换为多个JSON块。
|
||||
|
||||
功能描述:
|
||||
1. 每个JSON块的格式为 {"序号0": "文本0", "序号1": "文本1", ...}。
|
||||
2. 每个JSON块经过UTF-8编码后的字节大小不超过 chunk_size_max(若单行文本就超出了chunk_size_max则保留单行文本)。
|
||||
3. 如果单个文本段本身就超过大小限制,它将被自动分割成多个子文本段。
|
||||
4. 返回值是一个元组,包含两个列表:
|
||||
- json_chunks_list: 分块后的JSON字典列表。
|
||||
- merged_indices_list: 一个元组列表,记录了被分割的文本段在新的序号系统中的起始和结束序号。
|
||||
(函数注释不变)
|
||||
"""
|
||||
|
||||
# === 第一部分:预处理,将过长的segment拆分成更小的部分 ===
|
||||
# === 第一部分:预处理 (这部分逻辑可以保持不变) ===
|
||||
new_segments = []
|
||||
merged_indices_list = []
|
||||
|
||||
for segment in segments:
|
||||
# 检查单个segment(作为一个JSON对象的值)是否已超限
|
||||
if get_json_size({len(new_segments): segment}) > chunk_size_max:
|
||||
# 使用一个较长的key来预估,避免key长度变化带来的误差
|
||||
long_key_estimate = str(len(segments) + len(new_segments))
|
||||
if get_json_size({long_key_estimate: segment}) > chunk_size_max:
|
||||
sub_segments = []
|
||||
lines = segment.splitlines(keepends=True)
|
||||
current_sub_segment = ""
|
||||
for line in lines:
|
||||
next_sub_segment = current_sub_segment + line
|
||||
|
||||
# 预估下一个子段的大小
|
||||
# 使用一个临时的key(如0)来模拟
|
||||
if get_json_size({0: next_sub_segment}) > chunk_size_max:
|
||||
|
||||
# 如果 current_sub_segment 不为空,才将其添加
|
||||
# 这可以防止因第一行就超限而添加一个空字符串
|
||||
if get_json_size({long_key_estimate: next_sub_segment}) > chunk_size_max:
|
||||
if current_sub_segment:
|
||||
sub_segments.append(current_sub_segment)
|
||||
|
||||
# 即使单行超限,也必须作为一个独立的子段添加
|
||||
sub_segments.append(line)
|
||||
current_sub_segment = "" # 重置
|
||||
current_sub_segment = ""
|
||||
else:
|
||||
current_sub_segment = next_sub_segment
|
||||
|
||||
# 不要忘记循环结束后剩余的部分
|
||||
if current_sub_segment:
|
||||
sub_segments.append(current_sub_segment)
|
||||
|
||||
# 如果sub_segments为空(例如,原segment为空字符串),则添加一个空字符串以保持一致性
|
||||
if not sub_segments and segment == "":
|
||||
sub_segments.append("")
|
||||
|
||||
start_index = len(new_segments)
|
||||
new_segments.extend(sub_segments)
|
||||
end_index = len(new_segments)
|
||||
# 只有当一个segment被真正分割成多个部分时,才记录
|
||||
if end_index - start_index > 1:
|
||||
merged_indices_list.append((start_index, end_index))
|
||||
else:
|
||||
new_segments.append(segment)
|
||||
|
||||
# === 第二部分:将处理后的 new_segments 组合成 JSON 块 ===
|
||||
# === 第二部分:组合成 JSON 块 (修正部分) ===
|
||||
json_chunks_list = []
|
||||
if not new_segments: # 处理输入为空列表的边缘情况
|
||||
if not new_segments:
|
||||
return {}, [], []
|
||||
|
||||
js={}
|
||||
chunk = {}
|
||||
for key, val in enumerate(new_segments):
|
||||
# 预先构建下一个块的样子来检查大小
|
||||
prospective_chunk = chunk.copy()
|
||||
prospective_chunk[str(key)] = val
|
||||
|
||||
# 检查 prospective_chunk 是否超限,并且当前 chunk 不为空
|
||||
# 如果 chunk 为空,意味着这个 val 本身就超限了,但我们必须接受它,
|
||||
# 因为它已经是最小单位了。这可以防止产生空字典。
|
||||
# 修复bug: 即使chunk为空,如果 prospective_chunk(即单个元素)已超限,
|
||||
# 也应该先提交旧的chunk。
|
||||
if get_json_size(prospective_chunk) > chunk_size_max and chunk:
|
||||
json_chunks_list.append(chunk) # 将旧的、未超限的块存入列表
|
||||
chunk = {str(key): val} # 用当前元素开始一个新的块
|
||||
json_chunks_list.append(chunk)
|
||||
chunk = {str(key): val}
|
||||
else:
|
||||
chunk = prospective_chunk # 未超限,更新块
|
||||
js[str(key)]=val
|
||||
chunk = prospective_chunk
|
||||
|
||||
# 循环结束后,将最后一个块加入列表
|
||||
if chunk:
|
||||
json_chunks_list.append(chunk)
|
||||
js.update(chunk)
|
||||
|
||||
# ==================== 核心修正 ====================
|
||||
# 根据完整的 new_segments 列表构建最终的、完整的 js 字典
|
||||
# 这确保了第一个返回值是完整的
|
||||
js = {str(i): segment for i, segment in enumerate(new_segments)}
|
||||
# ================================================
|
||||
|
||||
return js, json_chunks_list, merged_indices_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user