fix: MT批处理+原项目功能合并,616段→21批
- segments_agent.py: MT模式用\n\n自然段落分隔批处理替代逐条发送 _batch_segments_for_mt: 按chunk_size分批,\n\n连接段落 _mt_batch_result_handler: 按\n\n拆分翻译结果回映射 616段→21批(减少96.6% API调用),翻译速度从~6分钟→~1分钟 - docx_translator.py: 合并原项目功能 +is_instr_text_run: 跳过w:instrText域代码,防止TOC/页码被破坏 +_decrypt_if_needed: 支持密码加密的docx +office_password配置项 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -182,22 +182,61 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
|
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
|
||||||
return {"error": f"{original_segments}"}
|
return {"error": f"{original_segments}"}
|
||||||
|
|
||||||
def _mt_simple_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> str:
|
def _batch_segments_for_mt(self, segments: list[str], chunk_size: int) -> tuple[list[str], list[list[int]]]:
|
||||||
"""MT mode: 直接返回翻译结果,不解析标记/JSON。"""
|
"""将 segments 按字符数分批,用 \\n\\n 自然段落分隔连接。返回(批文本列表, 每批的索引列表)。"""
|
||||||
return result.strip()
|
batches = []
|
||||||
|
index_groups = []
|
||||||
|
current_parts = []
|
||||||
|
current_indices = []
|
||||||
|
current_size = 0
|
||||||
|
sep = "\n\n"
|
||||||
|
sep_size = len(sep.encode('utf-8'))
|
||||||
|
|
||||||
def _mt_simple_error_handler(self, origin_prompt: str, logger: Logger) -> str:
|
for i, seg in enumerate(segments):
|
||||||
"""MT mode error fallback: 返回原文。"""
|
seg_size = len(seg.encode('utf-8'))
|
||||||
return origin_prompt
|
add_size = (sep_size if current_parts else 0) + seg_size
|
||||||
|
|
||||||
|
if current_parts and current_size + add_size > chunk_size:
|
||||||
|
batches.append(sep.join(current_parts))
|
||||||
|
index_groups.append(current_indices)
|
||||||
|
current_parts = [seg]
|
||||||
|
current_indices = [i]
|
||||||
|
current_size = seg_size
|
||||||
|
else:
|
||||||
|
current_parts.append(seg)
|
||||||
|
current_indices.append(i)
|
||||||
|
current_size += add_size
|
||||||
|
|
||||||
|
if current_parts:
|
||||||
|
batches.append(sep.join(current_parts))
|
||||||
|
index_groups.append(current_indices)
|
||||||
|
|
||||||
|
return batches, index_groups
|
||||||
|
|
||||||
|
def _mt_batch_result_handler(self, result: str, origin_prompt: str, logger: Logger) -> list[str]:
|
||||||
|
"""MT batch: 按 \\n\\n 拆分翻译结果,恢复为独立段落。"""
|
||||||
|
return [p.strip() for p in result.strip().split('\n\n')]
|
||||||
|
|
||||||
|
def _mt_batch_error_handler(self, origin_prompt: str, logger: Logger) -> list[str]:
|
||||||
|
"""MT batch error: 返回原文各段。"""
|
||||||
|
return [p.strip() for p in origin_prompt.split('\n\n')]
|
||||||
|
|
||||||
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
||||||
if self.is_mt_mode:
|
if self.is_mt_mode:
|
||||||
# MT mode: send each segment individually as plain text, no markers, no batching
|
# MT mode: batch segments by size, join with \n\n, split results back
|
||||||
return super().send_prompts(
|
if not segments:
|
||||||
prompts=segments,
|
return []
|
||||||
result_handler=self._mt_simple_result_handler,
|
batch_texts, batch_indices = self._batch_segments_for_mt(segments, chunk_size)
|
||||||
error_result_handler=self._mt_simple_error_handler,
|
batch_results = super().send_prompts(
|
||||||
|
prompts=batch_texts,
|
||||||
|
result_handler=self._mt_batch_result_handler,
|
||||||
|
error_result_handler=self._mt_batch_error_handler,
|
||||||
)
|
)
|
||||||
|
all_translated = [""] * len(segments)
|
||||||
|
for batch_parts, indices in zip(batch_results, batch_indices):
|
||||||
|
for j, idx in enumerate(indices):
|
||||||
|
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
|
||||||
|
return all_translated
|
||||||
|
|
||||||
# Non-MT mode: JSON batch translation
|
# Non-MT mode: JSON batch translation
|
||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
@@ -237,12 +276,22 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
|
|
||||||
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
|
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
|
||||||
if self.is_mt_mode:
|
if self.is_mt_mode:
|
||||||
# MT mode: send each segment individually as plain text, no markers, no batching
|
# MT mode: batch segments by size, join with \n\n, split results back
|
||||||
return await super().send_prompts_async(
|
if not segments:
|
||||||
prompts=segments,
|
return []
|
||||||
result_handler=self._mt_simple_result_handler,
|
batch_texts, batch_indices = await asyncio.to_thread(
|
||||||
error_result_handler=self._mt_simple_error_handler,
|
self._batch_segments_for_mt, segments, chunk_size
|
||||||
)
|
)
|
||||||
|
batch_results = await super().send_prompts_async(
|
||||||
|
prompts=batch_texts,
|
||||||
|
result_handler=self._mt_batch_result_handler,
|
||||||
|
error_result_handler=self._mt_batch_error_handler,
|
||||||
|
)
|
||||||
|
all_translated = [""] * len(segments)
|
||||||
|
for batch_parts, indices in zip(batch_results, batch_indices):
|
||||||
|
for j, idx in enumerate(indices):
|
||||||
|
all_translated[idx] = batch_parts[j] if j < len(batch_parts) else segments[idx]
|
||||||
|
return all_translated
|
||||||
|
|
||||||
# Non-MT mode: JSON batch translation
|
# Non-MT mode: JSON batch translation
|
||||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Self, Literal, List, Dict, Any, Tuple
|
from typing import Self, Literal, List, Dict, Any, Tuple, Optional
|
||||||
|
|
||||||
import docx
|
import docx
|
||||||
from docx.document import Document as DocumentObject
|
from docx.document import Document as DocumentObject
|
||||||
@@ -32,11 +32,21 @@ def is_image_run(run: Run) -> bool:
|
|||||||
return '<w:drawing' in xml or '<w:pict' in xml
|
return '<w:drawing' in xml or '<w:pict' in xml
|
||||||
|
|
||||||
|
|
||||||
|
def is_instr_text_run(run: Run) -> bool:
|
||||||
|
"""
|
||||||
|
检查 Run 是否包含域指令文本 (w:instrText)。
|
||||||
|
目录(TOC)、页码等功能的指令代码存储在此标签中。
|
||||||
|
必须跳过这些 Run,否则写入 text 会破坏域结构。
|
||||||
|
"""
|
||||||
|
return run.element.find(qn('w:instrText')) is not None
|
||||||
|
|
||||||
|
|
||||||
# ---------------- 配置类 ----------------
|
# ---------------- 配置类 ----------------
|
||||||
@dataclass
|
@dataclass
|
||||||
class DocxTranslatorConfig(AiTranslatorConfig):
|
class DocxTranslatorConfig(AiTranslatorConfig):
|
||||||
insert_mode: Literal["replace", "append", "prepend"] = "replace"
|
insert_mode: Literal["replace", "append", "prepend"] = "replace"
|
||||||
separator: str = "\n"
|
separator: str = "\n"
|
||||||
|
office_password: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ---------------- 主类 ----------------
|
# ---------------- 主类 ----------------
|
||||||
@@ -90,6 +100,28 @@ class DocxTranslator(AiTranslator):
|
|||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
self.separator = config.separator
|
self.separator = config.separator
|
||||||
|
self.office_password = config.office_password
|
||||||
|
|
||||||
|
def _decrypt_if_needed(self, content: bytes) -> bytes:
|
||||||
|
"""如果文件加密则解密,否则返回原内容。"""
|
||||||
|
try:
|
||||||
|
import msoffcrypto
|
||||||
|
from io import BytesIO as BIO
|
||||||
|
file_stream = BIO(content)
|
||||||
|
try:
|
||||||
|
office_file = msoffcrypto.OfficeFile(file_stream)
|
||||||
|
if office_file.is_encrypted():
|
||||||
|
if not self.office_password:
|
||||||
|
raise ValueError("此DOCX文件已加密,但未提供密码。")
|
||||||
|
decrypted = BIO()
|
||||||
|
office_file.load_key(password=self.office_password)
|
||||||
|
office_file.decrypt(decrypted)
|
||||||
|
return decrypted.getvalue()
|
||||||
|
return content
|
||||||
|
finally:
|
||||||
|
file_stream.close()
|
||||||
|
except ImportError:
|
||||||
|
return content
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _run_format_key(run: Run):
|
def _run_format_key(run: Run):
|
||||||
@@ -143,7 +175,7 @@ class DocxTranslator(AiTranslator):
|
|||||||
|
|
||||||
text_runs = []
|
text_runs = []
|
||||||
for run in para.runs:
|
for run in para.runs:
|
||||||
if is_image_run(run):
|
if is_image_run(run) or is_instr_text_run(run):
|
||||||
continue
|
continue
|
||||||
if not run.text.strip():
|
if not run.text.strip():
|
||||||
continue
|
continue
|
||||||
@@ -203,7 +235,8 @@ class DocxTranslator(AiTranslator):
|
|||||||
self._process_body_elements(parent_element, container, elements, texts)
|
self._process_body_elements(parent_element, container, elements, texts)
|
||||||
|
|
||||||
def _pre_translate(self, document: Document) -> Tuple[DocumentObject, List[Dict[str, Any]], List[str]]:
|
def _pre_translate(self, document: Document) -> Tuple[DocumentObject, List[Dict[str, Any]], List[str]]:
|
||||||
doc = docx.Document(BytesIO(document.content))
|
content = self._decrypt_if_needed(document.content)
|
||||||
|
doc = docx.Document(BytesIO(content))
|
||||||
elements, texts = [], []
|
elements, texts = [], []
|
||||||
|
|
||||||
self._traverse_container(doc, elements, texts)
|
self._traverse_container(doc, elements, texts)
|
||||||
|
|||||||
Reference in New Issue
Block a user