fix: MT批处理增加计数校验,不匹配时自动逐条回退
- MT_BATCH_SEP改为\n\n---\n\n,MT模型更不容易破坏 - _apply_mt_batch_results: 校验每批split count是否匹配预期 - 不匹配时标记为mismatched,_retranslate_mismatched逐条重译 - 批处理速度+逐条可靠性,两全兼顾 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -182,14 +182,16 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
|
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
|
||||||
return {"error": f"{original_segments}"}
|
return {"error": f"{original_segments}"}
|
||||||
|
|
||||||
|
MT_BATCH_SEP = "\n\n---\n\n"
|
||||||
|
|
||||||
def _batch_segments_for_mt(self, segments: list[str], chunk_size: int) -> tuple[list[str], list[list[int]]]:
|
def _batch_segments_for_mt(self, segments: list[str], chunk_size: int) -> tuple[list[str], list[list[int]]]:
|
||||||
"""将 segments 按字符数分批,用 \\n\\n 自然段落分隔连接。返回(批文本列表, 每批的索引列表)。"""
|
"""将 segments 按字符数分批,用分隔符连接。返回(批文本列表, 每批的索引列表)。"""
|
||||||
batches = []
|
batches = []
|
||||||
index_groups = []
|
index_groups = []
|
||||||
current_parts = []
|
current_parts = []
|
||||||
current_indices = []
|
current_indices = []
|
||||||
current_size = 0
|
current_size = 0
|
||||||
sep = "\n\n"
|
sep = self.MT_BATCH_SEP
|
||||||
sep_size = len(sep.encode('utf-8'))
|
sep_size = len(sep.encode('utf-8'))
|
||||||
|
|
||||||
for i, seg in enumerate(segments):
|
for i, seg in enumerate(segments):
|
||||||
@@ -214,16 +216,66 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
return batches, index_groups
|
return batches, index_groups
|
||||||
|
|
||||||
def _mt_batch_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> list[str]:
|
def _mt_batch_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> list[str]:
|
||||||
"""MT batch: 按 \\n\\n 拆分翻译结果,恢复为独立段落。"""
|
"""MT batch: 按分隔符拆分翻译结果。"""
|
||||||
return [p.strip() for p in result.strip().split('\n\n')]
|
return [p.strip() for p in result.strip().split(self.MT_BATCH_SEP)]
|
||||||
|
|
||||||
def _mt_batch_error_handler(self, origin_prompt: str, logger: Logger) -> list[str]:
|
def _mt_batch_error_handler(self, origin_prompt: str, logger: Logger) -> list[str]:
|
||||||
"""MT batch error: 返回原文各段。"""
|
"""MT batch error: 返回原文各段。"""
|
||||||
return [p.strip() for p in origin_prompt.split('\n\n')]
|
return [p.strip() for p in origin_prompt.split(self.MT_BATCH_SEP)]
|
||||||
|
|
||||||
|
def _mt_individual_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> str:
|
||||||
|
"""MT individual: 直接返回翻译结果。"""
|
||||||
|
return result.strip()
|
||||||
|
|
||||||
|
def _mt_individual_error_handler(self, origin_prompt: str, logger: Logger) -> str:
|
||||||
|
"""MT individual error: 返回原文。"""
|
||||||
|
return origin_prompt
|
||||||
|
|
||||||
|
def _apply_mt_batch_results(self, segments: list[str], batch_results: list,
|
||||||
|
batch_indices: list[list[int]]) -> list[str]:
|
||||||
|
"""应用批处理结果。对计数不匹配的批次,逐条回退重译。"""
|
||||||
|
all_translated = [""] * len(segments)
|
||||||
|
mismatch_batches = []
|
||||||
|
|
||||||
|
for batch_parts, indices in zip(batch_results, batch_indices):
|
||||||
|
if len(batch_parts) == len(indices):
|
||||||
|
for j, idx in enumerate(indices):
|
||||||
|
all_translated[idx] = batch_parts[j]
|
||||||
|
else:
|
||||||
|
self.logger.warning(
|
||||||
|
f"MT batch mismatch: got {len(batch_parts)} parts, expected {len(indices)}. "
|
||||||
|
f"Falling back to individual translation."
|
||||||
|
)
|
||||||
|
mismatch_batches.append(indices)
|
||||||
|
|
||||||
|
return all_translated, mismatch_batches
|
||||||
|
|
||||||
|
def _retranslate_mismatched(self, segments: list[str],
|
||||||
|
mismatch_batches: list[list[int]]) -> list[str]:
|
||||||
|
"""对计数不匹配的批次,逐条重新翻译。"""
|
||||||
|
# Collect all mismatched indices
|
||||||
|
all_mismatched = []
|
||||||
|
for indices in mismatch_batches:
|
||||||
|
all_mismatched.extend(indices)
|
||||||
|
|
||||||
|
if not all_mismatched:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.logger.info(f"Retranslating {len(all_mismatched)} mismatched segments individually")
|
||||||
|
individual_segments = [segments[i] for i in all_mismatched]
|
||||||
|
individual_results = super().send_prompts(
|
||||||
|
prompts=individual_segments,
|
||||||
|
result_handler=self._mt_individual_result_handler,
|
||||||
|
error_result_handler=self._mt_individual_error_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_map = {}
|
||||||
|
for idx, trans in zip(all_mismatched, individual_results):
|
||||||
|
result_map[idx] = trans
|
||||||
|
return result_map
|
||||||
|
|
||||||
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
||||||
if self.is_mt_mode:
|
if self.is_mt_mode:
|
||||||
# MT mode: batch segments by size, join with \n\n, split results back
|
|
||||||
if not segments:
|
if not segments:
|
||||||
return []
|
return []
|
||||||
batch_texts, batch_indices = self._batch_segments_for_mt(segments, chunk_size)
|
batch_texts, batch_indices = self._batch_segments_for_mt(segments, chunk_size)
|
||||||
@@ -232,10 +284,13 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
result_handler=self._mt_batch_result_handler,
|
result_handler=self._mt_batch_result_handler,
|
||||||
error_result_handler=self._mt_batch_error_handler,
|
error_result_handler=self._mt_batch_error_handler,
|
||||||
)
|
)
|
||||||
all_translated = [""] * len(segments)
|
all_translated, mismatched = self._apply_mt_batch_results(
|
||||||
for batch_parts, indices in zip(batch_results, batch_indices):
|
segments, batch_results, batch_indices
|
||||||
for j, idx in enumerate(indices):
|
)
|
||||||
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
|
if mismatched:
|
||||||
|
retranslated = self._retranslate_mismatched(segments, mismatched)
|
||||||
|
for idx, trans in retranslated.items():
|
||||||
|
all_translated[idx] = trans
|
||||||
return all_translated
|
return all_translated
|
||||||
|
|
||||||
# Non-MT mode: JSON batch translation
|
# Non-MT mode: JSON batch translation
|
||||||
@@ -276,7 +331,6 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
|
|
||||||
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
|
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
|
||||||
if self.is_mt_mode:
|
if self.is_mt_mode:
|
||||||
# MT mode: batch segments by size, join with \n\n, split results back
|
|
||||||
if not segments:
|
if not segments:
|
||||||
return []
|
return []
|
||||||
batch_texts, batch_indices = await asyncio.to_thread(
|
batch_texts, batch_indices = await asyncio.to_thread(
|
||||||
@@ -287,10 +341,15 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
result_handler=self._mt_batch_result_handler,
|
result_handler=self._mt_batch_result_handler,
|
||||||
error_result_handler=self._mt_batch_error_handler,
|
error_result_handler=self._mt_batch_error_handler,
|
||||||
)
|
)
|
||||||
all_translated = [""] * len(segments)
|
all_translated, mismatched = self._apply_mt_batch_results(
|
||||||
for batch_parts, indices in zip(batch_results, batch_indices):
|
segments, batch_results, batch_indices
|
||||||
for j, idx in enumerate(indices):
|
)
|
||||||
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
|
if mismatched:
|
||||||
|
retranslated = await asyncio.to_thread(
|
||||||
|
self._retranslate_mismatched, segments, mismatched
|
||||||
|
)
|
||||||
|
for idx, trans in retranslated.items():
|
||||||
|
all_translated[idx] = trans
|
||||||
return all_translated
|
return all_translated
|
||||||
|
|
||||||
# Non-MT mode: JSON batch translation
|
# Non-MT mode: JSON batch translation
|
||||||
|
|||||||
Reference in New Issue
Block a user