正确处理大的segments转chunks的情况

This commit is contained in:
xunbu
2025-08-18 09:33:05 +08:00
parent 7eeb9bb46d
commit a69c0562e6
5 changed files with 141 additions and 42 deletions

View File

@@ -1 +1 @@
__version__="1.1.3" __version__="1.1.4"

View File

@@ -1,10 +1,11 @@
import json import json
from json_repair import json_repair
from dataclasses import dataclass from dataclasses import dataclass
from json import JSONDecodeError from json import JSONDecodeError
from json_repair import json_repair
from docutranslate.agents import AgentConfig, Agent from docutranslate.agents import AgentConfig, Agent
from docutranslate.utils.json_utils import flat_json_split from docutranslate.utils.json_utils import flat_json_split, segments2json_chunks
@dataclass @dataclass
@@ -42,15 +43,14 @@ class SegmentsTranslateAgent(Agent):
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n' self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
def send_segments(self, segments: list[str], chunk_size: int): def send_segments(self, segments: list[str], chunk_size: int):
indexed_originals = {str(i): text for i, text in enumerate(segments)} indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
chunks = flat_json_split(indexed_originals, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts) translated_chunks = super().send_prompts(prompts=prompts)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk_str in translated_chunks: for chunk_str in translated_chunks:
try: try:
translated_part = json_repair.loads(chunk_str) translated_part = json_repair.loads(chunk_str)
for key,val in translated_part: for key, val in translated_part.items():
if key in indexed_translated: if key in indexed_translated:
indexed_translated[key] = val indexed_translated[key] = val
except JSONDecodeError as e: except JSONDecodeError as e:
@@ -58,26 +58,52 @@ class SegmentsTranslateAgent(Agent):
except ValueError as e: except ValueError as e:
self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}") self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}")
return list(indexed_translated.values()) # 初始化结果列表
result = []
last_end = 0
ls = list(indexed_translated.values())
for start, end in merged_indices_list:
# 添加未处理的部分
result.extend(ls[last_end:start])
# 合并切片范围内的元素
merged_item = "".join(ls[start:end])
result.append(merged_item)
last_end = end
# 添加剩余部分
result.extend(ls[last_end:])
return result
# todo:增加协程粒度 # todo:增加协程粒度
async def send_segments_async(self, segments: list[str], chunk_size: int): async def send_segments_async(self, segments: list[str], chunk_size: int):
indexed_originals = {str(i): text for i, text in enumerate(segments)} indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
chunks = flat_json_split(indexed_originals, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks] prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts) translated_chunks = await super().send_prompts_async(prompts=prompts)
indexed_translated = indexed_originals.copy() indexed_translated = indexed_originals.copy()
for chunk_str in translated_chunks: for chunk_str in translated_chunks:
try: try:
translated_part:dict = json_repair.loads(chunk_str) translated_part = json_repair.loads(chunk_str)
for key, val in translated_part.items(): for key, val in translated_part.items():
if key in indexed_translated: if key in indexed_translated:
indexed_translated[key] = val indexed_translated[key] = val
except JSONDecodeError as e: except JSONDecodeError as e:
self.logger.error(f"json解析错误解析文本:{chunk_str},错误:{e.__repr__()}") self.logger.info(f"json解析错误解析文本:{chunk_str},错误:{e.__repr__()}")
except ValueError as e: except ValueError as e:
self.logger.error(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}") self.logger.info(f"value错误更新对象:{indexed_translated},错误:{e.__repr__()}")
except AttributeError as e:
self.logger.error(f"属性错误,chunk_str:{chunk_str},错误:{e.__repr__()}")
return list(indexed_translated.values())
# 初始化结果列表
result = []
last_end = 0
ls=list(indexed_translated.values())
for start, end in merged_indices_list:
# 添加未处理的部分
result.extend(ls[last_end:start])
# 合并切片范围内的元素
merged_item = "".join(ls[start:end])
result.append(merged_item)
last_end = end
# 添加剩余部分
result.extend(ls[last_end:])
return result

View File

@@ -1,31 +1,98 @@
import json import json
def flat_json_split(js: dict, chunk_size_max: int) -> list[dict]: def get_json_size(js: dict) -> int:
"""计算字典转换成JSON字符串并以UTF-8编码后的字节大小"""
return len(json.dumps(js, ensure_ascii=False).encode('utf-8'))
def segments2json_chunks(segments: list[str], chunk_size_max: int) -> tuple[dict[str, str],
list[dict[str, str]], list[tuple[int, int]]]:
""" """
用给扁平的json形如{key:val}的分块每个分块大小不超过chunksize字节 将文本段列表segments转换为多个JSON块。
功能描述:
1. 每个JSON块的格式为 {"序号0": "文本0", "序号1": "文本1", ...}。
2. 每个JSON块经过UTF-8编码后的字节大小不超过 chunk_size_max若单行文本就超出了chunk_size_max则保留单行文本
3. 如果单个文本段本身就超过大小限制,它将被自动分割成多个子文本段。
4. 返回值是一个元组,包含两个列表:
- json_chunks_list: 分块后的JSON字典列表。
- merged_indices_list: 一个元组列表,记录了被分割的文本段在新的序号系统中的起始和结束序号。
""" """
chunks = []
chunk = {} # === 第一部分预处理将过长的segment拆分成更小的部分 ===
for key, val in js.items(): new_segments = []
t = chunk.copy() merged_indices_list = []
t[key] = val
chunk_size = get_json_size(t) for segment in segments:
if chunk_size <= chunk_size_max: # 检查单个segment作为一个JSON对象的值是否已超限
chunk[key] = val if get_json_size({len(new_segments): segment}) > chunk_size_max:
sub_segments = []
lines = segment.splitlines(keepends=True)
current_sub_segment = ""
for line in lines:
next_sub_segment = current_sub_segment + line
# 预估下一个子段的大小
# 使用一个临时的key如0来模拟
if get_json_size({0: next_sub_segment}) > chunk_size_max:
# 如果 current_sub_segment 不为空,才将其添加
# 这可以防止因第一行就超限而添加一个空字符串
if current_sub_segment:
sub_segments.append(current_sub_segment)
# 即使单行超限,也必须作为一个独立的子段添加
sub_segments.append(line)
current_sub_segment = "" # 重置
else: else:
chunks.append(chunk) current_sub_segment = next_sub_segment
chunk = {key:val}
chunks.append(chunk)
return chunks
# 不要忘记循环结束后剩余的部分
if current_sub_segment:
sub_segments.append(current_sub_segment)
# 如果sub_segments为空例如原segment为空字符串则添加一个空字符串以保持一致性
if not sub_segments and segment == "":
sub_segments.append("")
start_index = len(new_segments)
new_segments.extend(sub_segments)
end_index = len(new_segments)
# 只有当一个segment被真正分割成多个部分时才记录
if end_index - start_index > 1:
merged_indices_list.append((start_index, end_index))
else:
new_segments.append(segment)
# === 第二部分:将处理后的 new_segments 组合成 JSON 块 ===
json_chunks_list = []
if not new_segments: # 处理输入为空列表的边缘情况
return {}, [], []
js={}
chunk = {}
for key, val in enumerate(new_segments):
# 预先构建下一个块的样子来检查大小
prospective_chunk = chunk.copy()
prospective_chunk[str(key)] = val
# 检查 prospective_chunk 是否超限,并且当前 chunk 不为空
# 如果 chunk 为空,意味着这个 val 本身就超限了,但我们必须接受它,
# 因为它已经是最小单位了。这可以防止产生空字典。
if get_json_size(prospective_chunk) > chunk_size_max and chunk:
json_chunks_list.append(chunk) # 将旧的、未超限的块存入列表
chunk = {str(key): val} # 用当前元素开始一个新的块
else:
chunk = prospective_chunk # 未超限,更新块
js[str(key)]=val
# 循环结束后,将最后一个块加入列表
if chunk:
json_chunks_list.append(chunk)
js.update(chunk)
return js, json_chunks_list, merged_indices_list
def get_json_size(js: dict):
return len(json.dumps(js,ensure_ascii=False).encode())
if __name__ == '__main__': if __name__ == '__main__':
js={1:2,3:4,5:"哈哈"} print(get_json_size({"0": ""}))
ls=flat_json_split(js,30)
print(ls)
# for chunk in ls:
# print(len(chunk.encode()))

View File

@@ -14,6 +14,8 @@ dependencies = [
"mammoth>=1.10.0", "mammoth>=1.10.0",
"srt>=3.5.3", "srt>=3.5.3",
"lxml>=5.4.0", "lxml>=5.4.0",
"python-docx>=1.2.0",
"beautifulsoup4>=4.13.4",
] ]
dynamic = ["version"] dynamic = ["version"]

4
uv.lock generated
View File

@@ -315,6 +315,7 @@ wheels = [
name = "docutranslate" name = "docutranslate"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "beautifulsoup4" },
{ name = "fastapi", extra = ["standard"] }, { name = "fastapi", extra = ["standard"] },
{ name = "httpx" }, { name = "httpx" },
{ name = "json-repair" }, { name = "json-repair" },
@@ -323,6 +324,7 @@ dependencies = [
{ name = "mammoth" }, { name = "mammoth" },
{ name = "markdown2" }, { name = "markdown2" },
{ name = "openpyxl" }, { name = "openpyxl" },
{ name = "python-docx" },
{ name = "srt" }, { name = "srt" },
{ name = "xlsx2html" }, { name = "xlsx2html" },
] ]
@@ -342,6 +344,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "beautifulsoup4", specifier = ">=4.13.4" },
{ name = "docling", marker = "extra == 'docling'", specifier = ">=2.40.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.40.0" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.12" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.12" },
{ name = "httpx", specifier = "==0.27.2" }, { name = "httpx", specifier = "==0.27.2" },
@@ -352,6 +355,7 @@ requires-dist = [
{ name = "markdown2", specifier = ">=2.5.3" }, { name = "markdown2", specifier = ">=2.5.3" },
{ name = "opencv-python", marker = "extra == 'docling'", specifier = ">=4.11.0.86" }, { name = "opencv-python", marker = "extra == 'docling'", specifier = ">=4.11.0.86" },
{ name = "openpyxl", specifier = ">=3.1.5" }, { name = "openpyxl", specifier = ">=3.1.5" },
{ name = "python-docx", specifier = ">=1.2.0" },
{ name = "srt", specifier = ">=3.5.3" }, { name = "srt", specifier = ">=3.5.3" },
{ name = "xlsx2html", specifier = ">=0.6.2" }, { name = "xlsx2html", specifier = ">=0.6.2" },
] ]