fix: MT批处理+原项目功能合并,616段→21批

- segments_agent.py: MT模式用\n\n自然段落分隔批处理替代逐条发送
  _batch_segments_for_mt: 按chunk_size分批,\n\n连接段落
  _mt_batch_result_handler: 按\n\n拆分翻译结果回映射
  616段→21批(API调用减少96.6%),翻译耗时从约6分钟降至约1分钟
- docx_translator.py: 合并原项目功能
  +is_instr_text_run: 跳过w:instrText域代码,防止TOC/页码被破坏
  +_decrypt_if_needed: 支持密码加密的docx
  +office_password配置项

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-08 15:49:01 +08:00
parent 4cf1a8c67d
commit a8b8c416dd
2 changed files with 101 additions and 19 deletions

View File

@@ -182,22 +182,61 @@ class SegmentsTranslateAgent(Agent):
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
return {"error": f"{original_segments}"}
def _mt_simple_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> str:
"""MT mode: 直接返回翻译结果,不解析标记/JSON"""
return result.strip()
def _batch_segments_for_mt(self, segments: list[str], chunk_size: int) -> tuple[list[str], list[list[int]]]:
"""将 segments 按字符数分批,用 \\n\\n 自然段落分隔连接。返回(批文本列表, 每批的索引列表)"""
batches = []
index_groups = []
current_parts = []
current_indices = []
current_size = 0
sep = "\n\n"
sep_size = len(sep.encode('utf-8'))
def _mt_simple_error_handler(self, origin_prompt: str, logger: Logger) -> str:
"""MT mode error fallback: 返回原文。"""
return origin_prompt
for i, seg in enumerate(segments):
seg_size = len(seg.encode('utf-8'))
add_size = (sep_size if current_parts else 0) + seg_size
if current_parts and current_size + add_size > chunk_size:
batches.append(sep.join(current_parts))
index_groups.append(current_indices)
current_parts = [seg]
current_indices = [i]
current_size = seg_size
else:
current_parts.append(seg)
current_indices.append(i)
current_size += add_size
if current_parts:
batches.append(sep.join(current_parts))
index_groups.append(current_indices)
return batches, index_groups
def _mt_batch_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> list[str]:
"""MT batch: 按 \\n\\n 拆分翻译结果,恢复为独立段落。"""
return [p.strip() for p in result.strip().split('\n\n')]
def _mt_batch_error_handler(self, origin_prompt: str, logger: Logger) -> list[str]:
"""MT batch error: 返回原文各段。"""
return [p.strip() for p in origin_prompt.split('\n\n')]
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
if self.is_mt_mode:
# MT mode: send each segment individually as plain text, no markers, no batching
return super().send_prompts(
prompts=segments,
result_handler=self._mt_simple_result_handler,
error_result_handler=self._mt_simple_error_handler,
# MT mode: batch segments by size, join with \n\n, split results back
if not segments:
return []
batch_texts, batch_indices = self._batch_segments_for_mt(segments, chunk_size)
batch_results = super().send_prompts(
prompts=batch_texts,
result_handler=self._mt_batch_result_handler,
error_result_handler=self._mt_batch_error_handler,
)
all_translated = [""] * len(segments)
for batch_parts, indices in zip(batch_results, batch_indices):
for j, idx in enumerate(indices):
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
return all_translated
# Non-MT mode: JSON batch translation
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
@@ -237,12 +276,22 @@ class SegmentsTranslateAgent(Agent):
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
if self.is_mt_mode:
# MT mode: send each segment individually as plain text, no markers, no batching
return await super().send_prompts_async(
prompts=segments,
result_handler=self._mt_simple_result_handler,
error_result_handler=self._mt_simple_error_handler,
# MT mode: batch segments by size, join with \n\n, split results back
if not segments:
return []
batch_texts, batch_indices = await asyncio.to_thread(
self._batch_segments_for_mt, segments, chunk_size
)
batch_results = await super().send_prompts_async(
prompts=batch_texts,
result_handler=self._mt_batch_result_handler,
error_result_handler=self._mt_batch_error_handler,
)
all_translated = [""] * len(segments)
for batch_parts, indices in zip(batch_results, batch_indices):
for j, idx in enumerate(indices):
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
return all_translated
# Non-MT mode: JSON batch translation
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import Self, Literal, List, Dict, Any, Tuple
from typing import Self, Literal, List, Dict, Any, Tuple, Optional
import docx
from docx.document import Document as DocumentObject
@@ -32,11 +32,21 @@ def is_image_run(run: Run) -> bool:
return '<w:drawing' in xml or '<w:pict' in xml
def is_instr_text_run(run: Run) -> bool:
    """Tell whether a Run carries field-instruction text (``w:instrText``).

    The instruction codes behind features such as the table of contents
    (TOC) and page numbers are stored in this tag. Such runs must be
    skipped, otherwise writing to their text would corrupt the field
    structure.
    """
    instr_node = run.element.find(qn('w:instrText'))
    return instr_node is not None
# ---------------- 配置类 ----------------
@dataclass
class DocxTranslatorConfig(AiTranslatorConfig):
    """Settings for the DOCX translator, on top of ``AiTranslatorConfig``."""

    # One of "replace", "append" or "prepend".
    # NOTE(review): presumably decides whether the translation replaces,
    # follows or precedes the original text - confirm against the translator.
    insert_mode: Literal["replace", "append", "prepend"] = "replace"

    # NOTE(review): presumably the text placed between original and
    # translation in the non-replace modes - confirm against the translator.
    separator: str = "\n"

    # Password for opening an encrypted DOCX; None when none is supplied.
    office_password: Optional[str] = None
# ---------------- 主类 ----------------
@@ -90,6 +100,28 @@ class DocxTranslator(AiTranslator):
self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode
self.separator = config.separator
self.office_password = config.office_password
def _decrypt_if_needed(self, content: bytes) -> bytes:
"""如果文件加密则解密,否则返回原内容。"""
try:
import msoffcrypto
from io import BytesIO as BIO
file_stream = BIO(content)
try:
office_file = msoffcrypto.OfficeFile(file_stream)
if office_file.is_encrypted():
if not self.office_password:
raise ValueError("此DOCX文件已加密但未提供密码。")
decrypted = BIO()
office_file.load_key(password=self.office_password)
office_file.decrypt(decrypted)
return decrypted.getvalue()
return content
finally:
file_stream.close()
except ImportError:
return content
@staticmethod
def _run_format_key(run: Run):
@@ -143,7 +175,7 @@ class DocxTranslator(AiTranslator):
text_runs = []
for run in para.runs:
if is_image_run(run):
if is_image_run(run) or is_instr_text_run(run):
continue
if not run.text.strip():
continue
@@ -203,7 +235,8 @@ class DocxTranslator(AiTranslator):
self._process_body_elements(parent_element, container, elements, texts)
def _pre_translate(self, document: Document) -> Tuple[DocumentObject, List[Dict[str, Any]], List[str]]:
doc = docx.Document(BytesIO(document.content))
content = self._decrypt_if_needed(document.content)
doc = docx.Document(BytesIO(content))
elements, texts = [], []
self._traverse_container(doc, elements, texts)