正确处理大的segments转chunks的情况
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__="1.1.3"
|
||||
__version__="1.1.4"
|
||||
@@ -1,10 +1,11 @@
|
||||
import json
|
||||
from json_repair import json_repair
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
|
||||
from json_repair import json_repair
|
||||
|
||||
from docutranslate.agents import AgentConfig, Agent
|
||||
from docutranslate.utils.json_utils import flat_json_split
|
||||
from docutranslate.utils.json_utils import flat_json_split, segments2json_chunks
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,42 +43,67 @@ class SegmentsTranslateAgent(Agent):
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||
|
||||
def send_segments(self, segments: list[str], chunk_size: int):
|
||||
indexed_originals = {str(i): text for i, text in enumerate(segments)}
|
||||
chunks = flat_json_split(indexed_originals, chunk_size)
|
||||
prompts = [json.dumps(chunk,ensure_ascii=False) for chunk in chunks]
|
||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = super().send_prompts(prompts=prompts)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk_str in translated_chunks:
|
||||
try:
|
||||
translated_part = json_repair.loads(chunk_str)
|
||||
for key,val in translated_part:
|
||||
for key, val in translated_part.items():
|
||||
if key in indexed_translated:
|
||||
indexed_translated[key]=val
|
||||
indexed_translated[key] = val
|
||||
except JSONDecodeError as e:
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
||||
except ValueError as e:
|
||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||
|
||||
return list(indexed_translated.values())
|
||||
# 初始化结果列表
|
||||
result = []
|
||||
last_end = 0
|
||||
ls = list(indexed_translated.values())
|
||||
for start, end in merged_indices_list:
|
||||
# 添加未处理的部分
|
||||
result.extend(ls[last_end:start])
|
||||
# 合并切片范围内的元素
|
||||
merged_item = "".join(ls[start:end])
|
||||
result.append(merged_item)
|
||||
last_end = end
|
||||
|
||||
#todo:增加协程粒度
|
||||
# 添加剩余部分
|
||||
result.extend(ls[last_end:])
|
||||
return result
|
||||
|
||||
# todo:增加协程粒度
|
||||
async def send_segments_async(self, segments: list[str], chunk_size: int):
|
||||
indexed_originals = {str(i): text for i, text in enumerate(segments)}
|
||||
chunks = flat_json_split(indexed_originals, chunk_size)
|
||||
prompts = [json.dumps(chunk,ensure_ascii=False) for chunk in chunks]
|
||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk_str in translated_chunks:
|
||||
try:
|
||||
translated_part:dict = json_repair.loads(chunk_str)
|
||||
for key,val in translated_part.items():
|
||||
translated_part = json_repair.loads(chunk_str)
|
||||
for key, val in translated_part.items():
|
||||
if key in indexed_translated:
|
||||
indexed_translated[key]=val
|
||||
indexed_translated[key] = val
|
||||
except JSONDecodeError as e:
|
||||
self.logger.error(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
||||
except ValueError as e:
|
||||
self.logger.error(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||
except AttributeError as e:
|
||||
self.logger.error(f"属性错误,chunk_str:{chunk_str},错误:{e.__repr__()}")
|
||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||
|
||||
return list(indexed_translated.values())
|
||||
|
||||
# 初始化结果列表
|
||||
result = []
|
||||
last_end = 0
|
||||
ls=list(indexed_translated.values())
|
||||
for start, end in merged_indices_list:
|
||||
# 添加未处理的部分
|
||||
result.extend(ls[last_end:start])
|
||||
# 合并切片范围内的元素
|
||||
merged_item = "".join(ls[start:end])
|
||||
result.append(merged_item)
|
||||
last_end = end
|
||||
|
||||
# 添加剩余部分
|
||||
result.extend(ls[last_end:])
|
||||
return result
|
||||
|
||||
@@ -1,31 +1,98 @@
|
||||
import json
|
||||
|
||||
|
||||
def flat_json_split(js: dict, chunk_size_max: int) -> list[dict]:
|
||||
def get_json_size(js: dict) -> int:
|
||||
"""计算字典转换成JSON字符串并以UTF-8编码后的字节大小"""
|
||||
return len(json.dumps(js, ensure_ascii=False).encode('utf-8'))
|
||||
|
||||
|
||||
def segments2json_chunks(segments: list[str], chunk_size_max: int) -> tuple[dict[str, str],
|
||||
list[dict[str, str]], list[tuple[int, int]]]:
|
||||
"""
|
||||
用给扁平的json(形如{key:val})的分块,每个分块大小不超过chunksize字节
|
||||
将文本段列表(segments)转换为多个JSON块。
|
||||
|
||||
功能描述:
|
||||
1. 每个JSON块的格式为 {"序号0": "文本0", "序号1": "文本1", ...}。
|
||||
2. 每个JSON块经过UTF-8编码后的字节大小不超过 chunk_size_max(若单行文本就超出了chunk_size_max则保留单行文本)。
|
||||
3. 如果单个文本段本身就超过大小限制,它将被自动分割成多个子文本段。
|
||||
4. 返回值是一个元组,包含两个列表:
|
||||
- json_chunks_list: 分块后的JSON字典列表。
|
||||
- merged_indices_list: 一个元组列表,记录了被分割的文本段在新的序号系统中的起始和结束序号。
|
||||
"""
|
||||
chunks = []
|
||||
chunk = {}
|
||||
for key, val in js.items():
|
||||
t = chunk.copy()
|
||||
t[key] = val
|
||||
chunk_size = get_json_size(t)
|
||||
if chunk_size <= chunk_size_max:
|
||||
chunk[key] = val
|
||||
|
||||
# === 第一部分:预处理,将过长的segment拆分成更小的部分 ===
|
||||
new_segments = []
|
||||
merged_indices_list = []
|
||||
|
||||
for segment in segments:
|
||||
# 检查单个segment(作为一个JSON对象的值)是否已超限
|
||||
if get_json_size({len(new_segments): segment}) > chunk_size_max:
|
||||
sub_segments = []
|
||||
lines = segment.splitlines(keepends=True)
|
||||
current_sub_segment = ""
|
||||
for line in lines:
|
||||
next_sub_segment = current_sub_segment + line
|
||||
|
||||
# 预估下一个子段的大小
|
||||
# 使用一个临时的key(如0)来模拟
|
||||
if get_json_size({0: next_sub_segment}) > chunk_size_max:
|
||||
|
||||
# 如果 current_sub_segment 不为空,才将其添加
|
||||
# 这可以防止因第一行就超限而添加一个空字符串
|
||||
if current_sub_segment:
|
||||
sub_segments.append(current_sub_segment)
|
||||
# 即使单行超限,也必须作为一个独立的子段添加
|
||||
sub_segments.append(line)
|
||||
current_sub_segment = "" # 重置
|
||||
else:
|
||||
chunks.append(chunk)
|
||||
chunk = {key:val}
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
current_sub_segment = next_sub_segment
|
||||
|
||||
# 不要忘记循环结束后剩余的部分
|
||||
if current_sub_segment:
|
||||
sub_segments.append(current_sub_segment)
|
||||
|
||||
# 如果sub_segments为空(例如,原segment为空字符串),则添加一个空字符串以保持一致性
|
||||
if not sub_segments and segment == "":
|
||||
sub_segments.append("")
|
||||
|
||||
start_index = len(new_segments)
|
||||
new_segments.extend(sub_segments)
|
||||
end_index = len(new_segments)
|
||||
# 只有当一个segment被真正分割成多个部分时,才记录
|
||||
if end_index - start_index > 1:
|
||||
merged_indices_list.append((start_index, end_index))
|
||||
else:
|
||||
new_segments.append(segment)
|
||||
|
||||
# === 第二部分:将处理后的 new_segments 组合成 JSON 块 ===
|
||||
json_chunks_list = []
|
||||
if not new_segments: # 处理输入为空列表的边缘情况
|
||||
return {}, [], []
|
||||
|
||||
js={}
|
||||
chunk = {}
|
||||
for key, val in enumerate(new_segments):
|
||||
# 预先构建下一个块的样子来检查大小
|
||||
prospective_chunk = chunk.copy()
|
||||
prospective_chunk[str(key)] = val
|
||||
|
||||
# 检查 prospective_chunk 是否超限,并且当前 chunk 不为空
|
||||
# 如果 chunk 为空,意味着这个 val 本身就超限了,但我们必须接受它,
|
||||
# 因为它已经是最小单位了。这可以防止产生空字典。
|
||||
if get_json_size(prospective_chunk) > chunk_size_max and chunk:
|
||||
json_chunks_list.append(chunk) # 将旧的、未超限的块存入列表
|
||||
chunk = {str(key): val} # 用当前元素开始一个新的块
|
||||
else:
|
||||
chunk = prospective_chunk # 未超限,更新块
|
||||
js[str(key)]=val
|
||||
|
||||
# 循环结束后,将最后一个块加入列表
|
||||
if chunk:
|
||||
json_chunks_list.append(chunk)
|
||||
js.update(chunk)
|
||||
|
||||
return js, json_chunks_list, merged_indices_list
|
||||
|
||||
def get_json_size(js: dict):
|
||||
return len(json.dumps(js,ensure_ascii=False).encode())
|
||||
|
||||
if __name__ == '__main__':
|
||||
js={1:2,3:4,5:"哈哈"}
|
||||
ls=flat_json_split(js,30)
|
||||
print(ls)
|
||||
# for chunk in ls:
|
||||
# print(len(chunk.encode()))
|
||||
print(get_json_size({"0": ""}))
|
||||
|
||||
@@ -14,6 +14,8 @@ dependencies = [
|
||||
"mammoth>=1.10.0",
|
||||
"srt>=3.5.3",
|
||||
"lxml>=5.4.0",
|
||||
"python-docx>=1.2.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
||||
4
uv.lock
generated
4
uv.lock
generated
@@ -315,6 +315,7 @@ wheels = [
|
||||
name = "docutranslate"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "fastapi", extra = ["standard"] },
|
||||
{ name = "httpx" },
|
||||
{ name = "json-repair" },
|
||||
@@ -323,6 +324,7 @@ dependencies = [
|
||||
{ name = "mammoth" },
|
||||
{ name = "markdown2" },
|
||||
{ name = "openpyxl" },
|
||||
{ name = "python-docx" },
|
||||
{ name = "srt" },
|
||||
{ name = "xlsx2html" },
|
||||
]
|
||||
@@ -342,6 +344,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beautifulsoup4", specifier = ">=4.13.4" },
|
||||
{ name = "docling", marker = "extra == 'docling'", specifier = ">=2.40.0" },
|
||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.12" },
|
||||
{ name = "httpx", specifier = "==0.27.2" },
|
||||
@@ -352,6 +355,7 @@ requires-dist = [
|
||||
{ name = "markdown2", specifier = ">=2.5.3" },
|
||||
{ name = "opencv-python", marker = "extra == 'docling'", specifier = ">=4.11.0.86" },
|
||||
{ name = "openpyxl", specifier = ">=3.1.5" },
|
||||
{ name = "python-docx", specifier = ">=1.2.0" },
|
||||
{ name = "srt", specifier = ">=3.5.3" },
|
||||
{ name = "xlsx2html", specifier = ">=0.6.2" },
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user