From 4f6bd1bc7b504c820fd3c3bce04a9896dd9d52cd Mon Sep 17 00:00:00 2001 From: Leon Date: Mon, 8 Jun 2026 16:10:14 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20MT=E6=89=B9=E5=A4=84=E7=90=86=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E8=AE=A1=E6=95=B0=E6=A0=A1=E9=AA=8C=EF=BC=8C=E4=B8=8D?= =?UTF-8?q?=E5=8C=B9=E9=85=8D=E6=97=B6=E8=87=AA=E5=8A=A8=E9=80=90=E6=9D=A1?= =?UTF-8?q?=E5=9B=9E=E9=80=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MT_BATCH_SEP改为\n\n---\n\n,MT模型更不容易破坏 - _apply_mt_batch_results: 校验每批split count是否匹配预期 - 不匹配时标记为mismatched,_retranslate_mismatched逐条重译 - 批处理速度+逐条可靠性,两全兼顾 Co-Authored-By: Claude Opus 4.7 --- docutranslate/agents/segments_agent.py | 89 +++++++++++++++++++++----- 1 file changed, 74 insertions(+), 15 deletions(-) diff --git a/docutranslate/agents/segments_agent.py b/docutranslate/agents/segments_agent.py index e023941..a70d733 100644 --- a/docutranslate/agents/segments_agent.py +++ b/docutranslate/agents/segments_agent.py @@ -182,14 +182,16 @@ class SegmentsTranslateAgent(Agent): logger.error(f"原始prompt也不是有效的json格式: {original_segments}") return {"error": f"{original_segments}"} + MT_BATCH_SEP = "\n\n---\n\n" + def _batch_segments_for_mt(self, segments: list[str], chunk_size: int) -> tuple[list[str], list[list[int]]]: - """将 segments 按字符数分批,用 \\n\\n 自然段落分隔连接。返回(批文本列表, 每批的索引列表)。""" + """将 segments 按字符数分批,用分隔符连接。返回(批文本列表, 每批的索引列表)。""" batches = [] index_groups = [] current_parts = [] current_indices = [] current_size = 0 - sep = "\n\n" + sep = self.MT_BATCH_SEP sep_size = len(sep.encode('utf-8')) for i, seg in enumerate(segments): @@ -214,16 +216,66 @@ class SegmentsTranslateAgent(Agent): return batches, index_groups def _mt_batch_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> list[str]: - """MT batch: 按 \\n\\n 拆分翻译结果,恢复为独立段落。""" - return [p.strip() for p in result.strip().split('\n\n')] + """MT batch: 按分隔符拆分翻译结果。""" + return [p.strip() for p in result.strip().split(self.MT_BATCH_SEP)] def _mt_batch_error_handler(self, origin_prompt: str, logger: Logger) -> list[str]: """MT batch error: 返回原文各段。""" - return [p.strip() for p in origin_prompt.split('\n\n')] + return [p.strip() for p in origin_prompt.split(self.MT_BATCH_SEP)] + + def _mt_individual_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> str: + """MT individual: 直接返回翻译结果。""" + return result.strip() + + def _mt_individual_error_handler(self, origin_prompt: str, logger: Logger) -> str: + """MT individual error: 返回原文。""" + return origin_prompt + + def _apply_mt_batch_results(self, segments: list[str], batch_results: list, + batch_indices: list[list[int]]) -> list[str]: + """应用批处理结果。对计数不匹配的批次,逐条回退重译。""" + all_translated = [""] * len(segments) + mismatch_batches = [] + + for batch_parts, indices in zip(batch_results, batch_indices): + if len(batch_parts) == len(indices): + for j, idx in enumerate(indices): + all_translated[idx] = batch_parts[j] + else: + self.logger.warning( + f"MT batch mismatch: got {len(batch_parts)} parts, expected {len(indices)}. " + f"Falling back to individual translation." + ) + mismatch_batches.append(indices) + + return all_translated, mismatch_batches + + def _retranslate_mismatched(self, segments: list[str], + mismatch_batches: list[list[int]]) -> list[str]: + """对计数不匹配的批次,逐条重新翻译。""" + # Collect all mismatched indices + all_mismatched = [] + for indices in mismatch_batches: + all_mismatched.extend(indices) + + if not all_mismatched: + return [] + + self.logger.info(f"Retranslating {len(all_mismatched)} mismatched segments individually") + individual_segments = [segments[i] for i in all_mismatched] + individual_results = super().send_prompts( + prompts=individual_segments, + result_handler=self._mt_individual_result_handler, + error_result_handler=self._mt_individual_error_handler, + ) + + result_map = {} + for idx, trans in zip(all_mismatched, individual_results): + result_map[idx] = trans + return result_map def send_segments(self, segments: list[str], chunk_size: int) -> list[str]: if self.is_mt_mode: - # MT mode: batch segments by size, join with \n\n, split results back if not segments: return [] batch_texts, batch_indices = self._batch_segments_for_mt(segments, chunk_size) @@ -232,10 +284,13 @@ class SegmentsTranslateAgent(Agent): result_handler=self._mt_batch_result_handler, error_result_handler=self._mt_batch_error_handler, ) - all_translated = [""] * len(segments) - for batch_parts, indices in zip(batch_results, batch_indices): - for j, idx in enumerate(indices): - all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx] + all_translated, mismatched = self._apply_mt_batch_results( + segments, batch_results, batch_indices + ) + if mismatched: + retranslated = self._retranslate_mismatched(segments, mismatched) + for idx, trans in retranslated.items(): + all_translated[idx] = trans return all_translated # Non-MT mode: JSON batch translation @@ -276,7 +331,6 @@ class SegmentsTranslateAgent(Agent): async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]: if self.is_mt_mode: - # MT mode: batch segments by size, join with \n\n, split results back if not segments: return [] batch_texts, batch_indices = await asyncio.to_thread( @@ -287,10 +341,15 @@ class SegmentsTranslateAgent(Agent): result_handler=self._mt_batch_result_handler, error_result_handler=self._mt_batch_error_handler, ) - all_translated = [""] * len(segments) - for batch_parts, indices in zip(batch_results, batch_indices): - for j, idx in enumerate(indices): - all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx] + all_translated, mismatched = self._apply_mt_batch_results( + segments, batch_results, batch_indices + ) + if mismatched: + retranslated = await asyncio.to_thread( + self._retranslate_mismatched, segments, mismatched + ) + for idx, trans in retranslated.items(): + all_translated[idx] = trans return all_translated # Non-MT mode: JSON batch translation