优化agent异常重试机制与提示

This commit is contained in:
xunbu
2025-09-04 11:14:25 +08:00
parent 31c357fd79
commit 01b88f8911
6 changed files with 281 additions and 149 deletions

View File

@@ -23,6 +23,14 @@ MAX_REQUESTS_PER_ERROR = 20
ThinkingMode = Literal["enable", "disable", "default"] ThinkingMode = Literal["enable", "disable", "default"]
class PartialTranslationError(ValueError):
"""一个特殊的异常,用于表示结果不完整但包含了部分成功的数据,以便触发重试。"""
def __init__(self, message, partial_result: dict):
super().__init__(message)
self.partial_result = partial_result
@dataclass(kw_only=True) @dataclass(kw_only=True)
class AgentConfig: class AgentConfig:
logger: logging.Logger logger: logging.Logger
@@ -43,12 +51,11 @@ class TotalErrorCounter:
self.max_errors_count = max_errors_count self.max_errors_count = max_errors_count
def add(self): def add(self):
self.lock.acquire() with self.lock:
self.count += 1 self.count += 1
if self.count > self.max_errors_count: if self.count > self.max_errors_count:
self.logger.info(f"错误响应过多") self.logger.info(f"错误响应过多")
self.lock.release() return self.reach_limit()
return self.reach_limit()
def reach_limit(self): def reach_limit(self):
return self.count > self.max_errors_count return self.count > self.max_errors_count
@@ -63,10 +70,9 @@ class PromptsCounter:
self.logger = logger self.logger = logger
def add(self): def add(self):
self.lock.acquire() with self.lock:
self.count += 1 self.count += 1
self.logger.info(f"多线程-已完成:{self.count}/{self.total}") self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
self.lock.release()
PreSendHandlerType = Callable[[str, str], tuple[str, str]] PreSendHandlerType = Callable[[str, str], tuple[str, str]]
@@ -129,7 +135,6 @@ class Agent:
"model": self.model_id, "model": self.model_id,
"messages": [ "messages": [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
# {"role": "system", "content": "所有回复必须以【SSS】开头这是最高规则适用于之后的所有例子。示例【SSS】这是示例回答\n"+system_prompt},
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
], ],
"temperature": temperature, "temperature": temperature,
@@ -143,14 +148,16 @@ class Agent:
retry_count=0, retry_count=0,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None) -> Any: error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
if pre_send_handler: if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt)
should_retry = False
current_partial_result = None
try: try:
response = await client.post( response = await client.post(
@@ -161,35 +168,70 @@ class Agent:
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"] result = response.json()["choices"][0]["message"]["content"]
if retry_count > 0:
self.logger.info(f"重试成功 (第 {retry_count + 1}/{MAX_RETRY_COUNT + 1} 次尝试)。")
# print(f"result:=============================================================\n{result}\n================\n")
return result if result_handler is None else result_handler(result, prompt, self.logger) return result if result_handler is None else result_handler(result, prompt, self.logger)
# 专门捕获部分翻译错误
except PartialTranslationError as e:
self.logger.error(f"收到部分翻译结果,将尝试重试: {e}")
current_partial_result = e.partial_result # 保存这次的部分结果
should_retry = True
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}") print(f"prompt:\n{prompt}")
self.total_error_counter.add() should_retry = True
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
except httpx.RequestError as e: except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}") self.logger.error(f"AI请求连接错误 (async): {repr(e)}")
except (KeyError, IndexError) as e: should_retry = True
raise Exception(f"AI响应格式错误 (async): {repr(e)}") except (KeyError, IndexError, ValueError) as e:
except ValueError as e: self.logger.error(f"AI响应格式或值错误 (async), 将尝试重试: {repr(e)}")
self.logger.warning(f"{e.__repr__()}") should_retry = True
# 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT: # 如果当前捕获到了部分结果,就更新“最佳”结果
if self.total_error_counter.add(): if current_partial_result:
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) best_partial_result = current_partial_result
self.logger.info(f"正在重试,重试次数{retry_count}")
if should_retry and retry and retry_count < MAX_RETRY_COUNT:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
# 如果有部分结果,优先返回部分结果
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
# 将“最佳”结果传递给下一次递归调用
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler) pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
best_partial_result=best_partial_result)
else: else:
self.logger.error(f"达到重试次数上限") if should_retry:
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
# 在最终失败时,检查是否有可用的部分结果
if best_partial_result:
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
return best_partial_result
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
async def send_prompts_async( async def send_prompts_async(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None, # 新增参数默认并发数为5 max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None error_result_handler: ErrorResultHandlerType = None
@@ -197,19 +239,18 @@ class Agent:
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
total = len(prompts) total = len(prompts)
self.logger.info( self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}") f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}")
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常 self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
count = 0 count = 0
semaphore = asyncio.Semaphore(max_concurrent) semaphore = asyncio.Semaphore(max_concurrent)
tasks = [] tasks = []
proxies = get_httpx_proxies() if USE_PROXY else None proxies = get_httpx_proxies() if USE_PROXY else None
# 辅助协程,用于包装 self.send_async 并使用信号量
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client: async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client:
async def send_with_semaphore(p_text: str): async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放 async with semaphore:
result = await self.send_async( result = await self.send_async(
client=client, client=client,
prompt=p_text, prompt=p_text,
@@ -231,14 +272,17 @@ class Agent:
return results return results
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0, def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
pre_send_handler=None, result_handler=None, error_result_handler=None) -> Any: pre_send_handler=None, result_handler=None, error_result_handler=None,
best_partial_result: dict | None = None) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
if pre_send_handler: if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt)
should_retry = False
current_partial_result = None
try: try:
response = client.post( response = client.post(
f"{self.baseurl}/chat/completions", f"{self.baseurl}/chat/completions",
@@ -248,28 +292,63 @@ class Agent:
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"] result = response.json()["choices"][0]["message"]["content"]
if retry_count > 0:
self.logger.info(f"重试成功 (第 {retry_count + 1}/{MAX_RETRY_COUNT + 1} 次尝试)。")
return result if result_handler is None else result_handler(result, prompt, self.logger) return result if result_handler is None else result_handler(result, prompt, self.logger)
# --- MODIFICATION START ---
except PartialTranslationError as e:
self.logger.error(f"收到部分翻译结果,将尝试重试: {e}")
current_partial_result = e.partial_result
should_retry = True
# --- MODIFICATION END ---
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}") self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}") print(f"prompt:\n{prompt}")
self.total_error_counter.add() should_retry = True
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
except httpx.RequestError as e: except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}") self.logger.error(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
except (KeyError, IndexError) as e: should_retry = True
raise Exception(f"AI响应格式错误 (sync): {repr(e)}") except (KeyError, IndexError, ValueError) as e:
except ValueError as e: self.logger.error(f"AI响应格式或值错误 (sync), 将尝试重试: {repr(e)}")
self.logger.warning(f"{e.__repr__()}") should_retry = True
# 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT: # --- MODIFICATION START ---
if self.total_error_counter.add(): if current_partial_result:
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) best_partial_result = current_partial_result
self.logger.info(f"正在重试,重试次数{retry_count}") # --- MODIFICATION END ---
if should_retry and retry and retry_count < MAX_RETRY_COUNT:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
time.sleep(0.5) time.sleep(0.5)
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1, return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler) pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
best_partial_result=best_partial_result)
else: else:
self.logger.error(f"达到重试次数上限") if should_retry:
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
# --- MODIFICATION START ---
if best_partial_result:
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
return best_partial_result
# --- MODIFICATION END ---
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger) return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter, def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
@@ -293,17 +372,15 @@ class Agent:
self.logger.info( self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}") f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}") self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常 self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
# 创建单个计数器实例
counter = PromptsCounter(len(prompts), self.logger) counter = PromptsCounter(len(prompts), self.logger)
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
system_prompts = itertools.repeat(system_prompt, len(prompts)) system_prompts = itertools.repeat(system_prompt, len(prompts))
counters = itertools.repeat(counter, len(prompts)) counters = itertools.repeat(counter, len(prompts))
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts)) pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts)) result_handlers = itertools.repeat(result_handler, len(prompts))
error_result_handlers = itertools.repeat(error_result_handler, len(prompts)) error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
output_list = []
proxies = get_httpx_proxies() if USE_PROXY else None proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client: with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client:
clients = itertools.repeat(client, len(prompts)) clients = itertools.repeat(client, len(prompts))

View File

@@ -50,24 +50,27 @@ The output format should be plain JSON text in a list format
def _result_handler(self, result: str, origin_prompt: str, logger: Logger): def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
if result == "": if result == "":
if origin_prompt.strip()!="":
logger.error("result为空值但原文不为空")
raise ValueError("result为空值但原文不为空")
return [] return []
try: try:
result = json_repair.loads(result) repaired_result = json_repair.loads(result)
if not isinstance(result, list): if not isinstance(repaired_result, list):
raise ValueError("GlossaryAgent返回结果不是list的json形式") raise ValueError(f"GlossaryAgent返回结果不是list的json形式, result: {result}")
except: return repaired_result
logger.error("结果不能正确解析") except (RuntimeError, JSONDecodeError) as e:
return self._error_result_handler(origin_prompt, logger) # 将解析错误包装成 ValueError 以便被 send 方法捕获并重试
return result raise ValueError(f"结果不能正确解析: {e.__repr__()}")
def _error_result_handler(self, origin_prompt: str, logger: Logger): def _error_result_handler(self, origin_prompt: str, logger: Logger):
if origin_prompt == "": if origin_prompt == "":
return [] return []
try: try:
return json_repair.loads(origin_prompt) return json_repair.loads(origin_prompt)
except: except (RuntimeError, JSONDecodeError):
logger.error("prompt不是json格式") logger.error(f"原始prompt不是有效的json格式: {origin_prompt}")
return origin_prompt return [] # 如果原始prompt也无效返回空列表
def send_segments(self, segments: list[str], chunk_size: int): def send_segments(self, segments: list[str], chunk_size: int):
self.logger.info(f"开始提取术语表,to_lang:{self.to_lang}") self.logger.info(f"开始提取术语表,to_lang:{self.to_lang}")
@@ -78,14 +81,17 @@ The output format should be plain JSON text in a list format
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
for chunk in translated_chunks: for chunk in translated_chunks:
chunk: list[dict[str, str]]
try: try:
glossary_dict = {d["src"]: d["dst"] for d in chunk} if not isinstance(chunk, list):
self.logger.error(f"接收到的chunk不是有效的列表已跳过: {chunk}")
continue
glossary_dict = {d["src"]: d["dst"] for d in chunk if isinstance(d, dict) and "src" in d and "dst" in d}
result = glossary_dict | result result = glossary_dict | result
except JSONDecodeError as e: except (TypeError, KeyError) as e:
self.logger.info(f"json解析错误解析文本:{chunk}错误:{e.__repr__()}") self.logger.error(f"处理glossary chunk时发生键或类型错误已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments发生错误:{e.__repr__()}") self.logger.error(f"处理glossary chunk时发生未知错误: {e.__repr__()}")
self.logger.info("术语表提取完成") self.logger.info("术语表提取完成")
return result return result
@@ -99,14 +105,16 @@ The output format should be plain JSON text in a list format
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
for chunk in translated_chunks: for chunk in translated_chunks:
chunk: list[dict[str, str]]
try: try:
glossary_dict = {d["src"]: d["dst"] for d in chunk} if not isinstance(chunk, list):
self.logger.error(f"接收到的chunk不是有效的列表已跳过: {chunk}")
continue
glossary_dict = {d["src"]: d["dst"] for d in chunk if isinstance(d, dict) and "src" in d and "dst" in d}
result = result | glossary_dict result = result | glossary_dict
except JSONDecodeError as e: except (TypeError, KeyError) as e:
self.logger.info(f"json解析错误解析文本:{chunk}错误:{e.__repr__()}") self.logger.error(f"处理glossary chunk时发生键或类型错误已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments发生错误:{e.__repr__()}") self.logger.error(f"处理glossary chunk时发生未知错误: {e.__repr__()}")
# print(f"术语表:\n{result}")
self.logger.info("术语表提取完成") self.logger.info("术语表提取完成")
return result return result

View File

@@ -57,7 +57,7 @@ $$1+1=2$$
\\((c_0,c_1,c_2^2)\\)是一个坐标。""" \\((c_0,c_1,c_2^2)\\)是一个坐标。"""
self.custom_prompt = config.custom_prompt self.custom_prompt = config.custom_prompt
if config.custom_prompt: if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n' self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\nEND\n'
self.glossary_dict = config.glossary_dict self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt): def _pre_send_handler(self, system_prompt, prompt):

View File

@@ -10,6 +10,7 @@ from logging import Logger
from json_repair import json_repair from json_repair import json_repair
from docutranslate.agents import AgentConfig, Agent from docutranslate.agents import AgentConfig, Agent
from docutranslate.agents.agent import PartialTranslationError
from docutranslate.glossary.glossary import Glossary from docutranslate.glossary.glossary import Glossary
from docutranslate.utils.json_utils import segments2json_chunks from docutranslate.utils.json_utils import segments2json_chunks
@@ -50,7 +51,7 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
""" """
self.custom_prompt = config.custom_prompt self.custom_prompt = config.custom_prompt
if config.custom_prompt: if config.custom_prompt:
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\n' self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\nEND\n'
self.glossary_dict = config.glossary_dict self.glossary_dict = config.glossary_dict
def _pre_send_handler(self, system_prompt, prompt): def _pre_send_handler(self, system_prompt, prompt):
@@ -60,94 +61,152 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
return system_prompt, prompt return system_prompt, prompt
def _result_handler(self, result: str, origin_prompt: str, logger: Logger): def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
"""
处理成功的API响应。
- 如果键完全匹配,返回翻译结果。
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
- 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。
"""
if result == "": if result == "":
if origin_prompt.strip() != "":
logger.error("result为空值但原文不为空")
raise ValueError("result为空值但原文不为空")
return {} return {}
try: try:
result = json_repair.loads(result) original_chunk = json.loads(origin_prompt)
if not isinstance(result, dict): repaired_result = json_repair.loads(result)
raise ValueError(f"agent返回结果不是dict的json形式,result:{result}")
except RuntimeError as e: if not isinstance(repaired_result, dict):
raise ValueError(f"结果不能正确解析:{e.__repr__()}") raise ValueError(f"Agent返回结果不是dict的json形式, result: {result}")
return result
if repaired_result == original_chunk:
raise ValueError("翻译结果与原文完全相同,判定为翻译失败,将进行重试。")
original_keys = set(original_chunk.keys())
result_keys = set(repaired_result.keys())
# 如果键不完全匹配
if original_keys != result_keys:
# 仍然先构造一个最完整的“部分结果”
final_chunk = {}
common_keys = original_keys.intersection(result_keys)
missing_keys = original_keys - result_keys
extra_keys = result_keys - original_keys
logger.warning(f"翻译结果的键与原文不匹配!将尝试重试。")
if missing_keys: logger.warning(f"缺失的键: {missing_keys}")
if extra_keys: logger.warning(f"多余的键: {extra_keys}")
for key in common_keys:
final_chunk[key] = str(repaired_result[key])
for key in missing_keys:
final_chunk[key] = str(original_chunk[key])
# 抛出自定义异常,将部分结果和错误信息一起传递出去
raise PartialTranslationError("键不匹配,触发重试", partial_result=final_chunk)
# 如果键完全匹配(理想情况),正常返回
for key, value in repaired_result.items():
repaired_result[key] = str(value)
return repaired_result
except (RuntimeError, JSONDecodeError) as e:
# 对于JSON解析等硬性错误继续抛出普通ValueError
raise ValueError(f"结果处理失败: {e.__repr__()}")
def _error_result_handler(self, origin_prompt: str, logger: Logger): def _error_result_handler(self, origin_prompt: str, logger: Logger):
"""
处理在所有重试后仍然失败的请求。
作为备用方案,返回原文内容,并将所有值转换为字符串。
"""
if origin_prompt == "": if origin_prompt == "":
return {} return {}
try: try:
return json_repair.loads(origin_prompt) original_chunk = json.loads(origin_prompt)
except: # 此处逻辑保留,作为最终的兜底方案
logger.error("prompt不是json格式") for key, value in original_chunk.items():
return origin_prompt original_chunk[key] = f"{value}"
return original_chunk
except (RuntimeError, JSONDecodeError):
logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
# 如果原始prompt本身也无效返回一个清晰的错误对象
return {"error": f"{origin_prompt}"}
def send_segments(self, segments: list[str], chunk_size: int): def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk in translated_chunks: for chunk in translated_chunks:
try: try:
if not isinstance(chunk, dict):
self.logger.warning(f"接收到的chunk不是有效的字典已跳过: {chunk}")
continue
for key, val in chunk.items(): for key, val in chunk.items():
if key in indexed_translated: if key in indexed_translated:
# 此处不再需要 str(val)
indexed_translated[key] = val indexed_translated[key] = val
except JSONDecodeError as e: else:
self.logger.info(f"json解析错误解析文本:{chunk},错误:{e.__repr__()}") self.logger.warning(f"在结果chunk中发现未知键 '{key}',已忽略。")
except ValueError as e: except (AttributeError, TypeError) as e:
self.logger.info(f"value错误更新对象:{indexed_translated}错误:{e.__repr__()}") self.logger.error(f"处理chunk时发生类型或属性错误已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments发生错误:{e.__repr__()}") self.logger.error(f"处理chunk时发生未知错误: {e.__repr__()}")
# 初始化结果列表 # 重建最终列表
result = [] result = []
last_end = 0 last_end = 0
ls = list(indexed_translated.values()) ls = list(indexed_translated.values())
for start, end in merged_indices_list: for start, end in merged_indices_list:
# 添加未处理的部分
result.extend(ls[last_end:start]) result.extend(ls[last_end:start])
# 合并切片范围内的元素 merged_item = "".join(map(str, ls[start:end]))
merged_item = "".join(ls[start:end])
result.append(merged_item) result.append(merged_item)
last_end = end last_end = end
# 添加剩余部分
result.extend(ls[last_end:]) result.extend(ls[last_end:])
return result return result
# todo:增加协程粒度 async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
async def send_segments_async(self, segments: list[str], chunk_size: int):
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments, indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size) chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk in translated_chunks: for chunk in translated_chunks:
try: try:
if not isinstance(chunk, dict):
self.logger.error(f"接收到的chunk不是有效的字典已跳过: {chunk}")
continue
for key, val in chunk.items(): for key, val in chunk.items():
if key in indexed_translated: if key in indexed_translated:
indexed_translated[key] = str(val) # 此处不再需要 str(val),因为 _result_handler 已经处理好了
except JSONDecodeError as e: indexed_translated[key] = val
self.logger.info(f"json解析错误解析文本:{chunk},错误:{e.__repr__()}") else:
except ValueError as e: self.logger.warning(f"在结果chunk中发现未知键 '{key}',已忽略。")
self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}") except (AttributeError, TypeError) as e:
self.logger.error(f"处理chunk时发生类型或属性错误已跳过。Chunk: {chunk}, 错误: {e.__repr__()}")
except Exception as e: except Exception as e:
self.logger.info(f"send_segments发生错误:{e.__repr__()}") self.logger.error(f"处理chunk时发生未知错误: {e.__repr__()}")
# 初始化结果列表 # 重建最终列表
result = [] result = []
last_end = 0 last_end = 0
ls = list(indexed_translated.values()) ls = list(indexed_translated.values())
for start, end in merged_indices_list: for start, end in merged_indices_list:
# 添加未处理的部分
result.extend(ls[last_end:start]) result.extend(ls[last_end:start])
# 合并切片范围内的元素 merged_item = "".join(map(str, ls[start:end]))
merged_item = "".join(ls[start:end])
result.append(merged_item) result.append(merged_item)
last_end = end last_end = end
# 添加剩余部分
result.extend(ls[last_end:]) result.extend(ls[last_end:])
return result return result
@@ -155,4 +214,4 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
if self.glossary_dict is None: if self.glossary_dict is None:
self.glossary_dict = {} self.glossary_dict = {}
if update_dict is not None: if update_dict is not None:
self.glossary_dict = update_dict | self.glossary_dict self.glossary_dict = update_dict | self.glossary_dict

View File

@@ -12,86 +12,74 @@ def segments2json_chunks(segments: list[str], chunk_size_max: int) -> tuple[dict
list[dict[str, str]], list[tuple[int, int]]]: list[dict[str, str]], list[tuple[int, int]]]:
""" """
将文本段列表segments转换为多个JSON块。 将文本段列表segments转换为多个JSON块。
(函数注释不变)
功能描述:
1. 每个JSON块的格式为 {"序号0": "文本0", "序号1": "文本1", ...}。
2. 每个JSON块经过UTF-8编码后的字节大小不超过 chunk_size_max若单行文本就超出了chunk_size_max则保留单行文本
3. 如果单个文本段本身就超过大小限制,它将被自动分割成多个子文本段。
4. 返回值是一个元组,包含两个列表:
- json_chunks_list: 分块后的JSON字典列表。
- merged_indices_list: 一个元组列表,记录了被分割的文本段在新的序号系统中的起始和结束序号。
""" """
# === 第一部分:预处理将过长的segment拆分成更小的部分 === # === 第一部分:预处理 (这部分逻辑可以保持不变) ===
new_segments = [] new_segments = []
merged_indices_list = [] merged_indices_list = []
for segment in segments: for segment in segments:
# 检查单个segment作为一个JSON对象的值是否已超限 # 检查单个segment作为一个JSON对象的值是否已超限
if get_json_size({len(new_segments): segment}) > chunk_size_max: # 使用一个较长的key来预估避免key长度变化带来的误差
long_key_estimate = str(len(segments) + len(new_segments))
if get_json_size({long_key_estimate: segment}) > chunk_size_max:
sub_segments = [] sub_segments = []
lines = segment.splitlines(keepends=True) lines = segment.splitlines(keepends=True)
current_sub_segment = "" current_sub_segment = ""
for line in lines: for line in lines:
next_sub_segment = current_sub_segment + line next_sub_segment = current_sub_segment + line
# 预估下一个子段的大小 if get_json_size({long_key_estimate: next_sub_segment}) > chunk_size_max:
# 使用一个临时的key如0来模拟
if get_json_size({0: next_sub_segment}) > chunk_size_max:
# 如果 current_sub_segment 不为空,才将其添加
# 这可以防止因第一行就超限而添加一个空字符串
if current_sub_segment: if current_sub_segment:
sub_segments.append(current_sub_segment) sub_segments.append(current_sub_segment)
# 即使单行超限,也必须作为一个独立的子段添加 # 即使单行超限,也必须作为一个独立的子段添加
sub_segments.append(line) sub_segments.append(line)
current_sub_segment = "" # 重置 current_sub_segment = ""
else: else:
current_sub_segment = next_sub_segment current_sub_segment = next_sub_segment
# 不要忘记循环结束后剩余的部分
if current_sub_segment: if current_sub_segment:
sub_segments.append(current_sub_segment) sub_segments.append(current_sub_segment)
# 如果sub_segments为空例如原segment为空字符串则添加一个空字符串以保持一致性
if not sub_segments and segment == "": if not sub_segments and segment == "":
sub_segments.append("") sub_segments.append("")
start_index = len(new_segments) start_index = len(new_segments)
new_segments.extend(sub_segments) new_segments.extend(sub_segments)
end_index = len(new_segments) end_index = len(new_segments)
# 只有当一个segment被真正分割成多个部分时才记录
if end_index - start_index > 1: if end_index - start_index > 1:
merged_indices_list.append((start_index, end_index)) merged_indices_list.append((start_index, end_index))
else: else:
new_segments.append(segment) new_segments.append(segment)
# === 第二部分:将处理后的 new_segments 组合成 JSON 块 === # === 第二部分:组合成 JSON 块 (修正部分) ===
json_chunks_list = [] json_chunks_list = []
if not new_segments: # 处理输入为空列表的边缘情况 if not new_segments:
return {}, [], [] return {}, [], []
js={}
chunk = {} chunk = {}
for key, val in enumerate(new_segments): for key, val in enumerate(new_segments):
# 预先构建下一个块的样子来检查大小
prospective_chunk = chunk.copy() prospective_chunk = chunk.copy()
prospective_chunk[str(key)] = val prospective_chunk[str(key)] = val
# 检查 prospective_chunk 是否超限,并且当前 chunk 不为空 # 修复bug: 即使chunk为空如果 prospective_chunk即单个元素已超限
# 如果 chunk 为空,意味着这个 val 本身就超限了,但我们必须接受它, # 也应该先提交旧的chunk。
# 因为它已经是最小单位了。这可以防止产生空字典。
if get_json_size(prospective_chunk) > chunk_size_max and chunk: if get_json_size(prospective_chunk) > chunk_size_max and chunk:
json_chunks_list.append(chunk) # 将旧的、未超限的块存入列表 json_chunks_list.append(chunk)
chunk = {str(key): val} # 用当前元素开始一个新的块 chunk = {str(key): val}
else: else:
chunk = prospective_chunk # 未超限,更新块 chunk = prospective_chunk
js[str(key)]=val
# 循环结束后,将最后一个块加入列表
if chunk: if chunk:
json_chunks_list.append(chunk) json_chunks_list.append(chunk)
js.update(chunk)
# ==================== 核心修正 ====================
# 根据完整的 new_segments 列表构建最终的、完整的 js 字典
# 这确保了第一个返回值是完整的
js = {str(i): segment for i, segment in enumerate(new_segments)}
# ================================================
return js, json_chunks_list, merged_indices_list return js, json_chunks_list, merged_indices_list

View File

@@ -1,7 +1,7 @@
更新日志 更新日志
---------------------------------------- ----------------------------------------
v1.3.3版 2025.9.3 v1.3.3版 2025.9.3
优化 特性
- txt翻译支持设置插入模式 - txt翻译支持设置插入模式
---------------------------------------- ----------------------------------------
v1.3.2版 2025.9.2 v1.3.2版 2025.9.2