v0.0.1
This commit is contained in:
0
filetranslate/utils/__init__.py
Normal file
0
filetranslate/utils/__init__.py
Normal file
109
filetranslate/utils/agent_utils.py
Normal file
109
filetranslate/utils/agent_utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, baseurl="", key="", model_id="", system_prompt="", temperature=0.7, max_concurrent=5):
|
||||
self.baseurl = baseurl
|
||||
self.key = key
|
||||
self.model_id = model_id
|
||||
self.system_prompt = system_prompt
|
||||
self.temperature = temperature
|
||||
# self.client=httpx.Client()
|
||||
self.client_async = httpx.AsyncClient()
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
def _prepare_request_data(self, prompt, system_prompt, temperature=None, top_p=0.9):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
headers = {"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.key}"}
|
||||
data = {
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": temperature,
|
||||
"top_p": top_p
|
||||
}
|
||||
return headers, data
|
||||
|
||||
# def send_prompt(self,prompt,system_prompt=None,timeout=50):
|
||||
# if system_prompt is None:
|
||||
# system_prompt=self.system_prompt
|
||||
# headers,data=self._prepare_request_data(prompt,system_prompt)
|
||||
# response=self.client.post(f"{self.baseurl}/chat/completions",json=data,headers=headers,timeout=timeout)
|
||||
# response.raise_for_status()
|
||||
# return response.json()["choices"][0]["message"]["content"].lstrip()
|
||||
|
||||
async def send_async(self, prompt: str, system_prompt: None | str = None, timeout: int = 200) -> str:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
"""Sends a single prompt asynchronously."""
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
try:
|
||||
response = await self.client_async.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
result: str = response.json()["choices"][0]["message"]["content"]
|
||||
return result.lstrip()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise Exception(f"AI请求连接错误 (async): {e}") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise Exception(f"AI响应格式错误 (async): {e}") from e
|
||||
|
||||
async def send_prompts_async(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
timeout: int = 50,
|
||||
max_concurrent: int = 5 # 新增参数,默认并发数为5
|
||||
) -> list[str]:
|
||||
total = len(prompts)
|
||||
count = 0
|
||||
"""
|
||||
Sends multiple prompts asynchronously, limiting concurrent requests.
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
|
||||
# 辅助协程,用于包装 self.send_async 并使用信号量
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||
result = await self.send_async(
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
timeout=timeout
|
||||
)
|
||||
nonlocal count
|
||||
count += 1
|
||||
print(f"进行到{count}/{total}")
|
||||
return result
|
||||
|
||||
for p_text in prompts:
|
||||
task = asyncio.create_task(send_with_semaphore(p_text))
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
return results
|
||||
|
||||
def send_prompts(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
timeout: int = 50,
|
||||
) -> list[str]:
|
||||
result = asyncio.run(self.send_prompts_async(prompts, system_prompt, timeout, self.max_concurrent))
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
25
filetranslate/utils/convert.py
Normal file
25
filetranslate/utils/convert.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling_core.types.doc import ImageRefMode
|
||||
from pathlib import Path
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
|
||||
IMAGE_RESOLUTION_SCALE = 4
|
||||
|
||||
|
||||
def pdf2markdown_embed_images(pdf: Path | str, formula=False, code=False) -> str:
|
||||
pipeline_options = PdfPipelineOptions()
|
||||
pipeline_options.images_scale = IMAGE_RESOLUTION_SCALE
|
||||
pipeline_options.generate_picture_images = True
|
||||
if formula:
|
||||
pipeline_options.do_formula_enrichment=True
|
||||
if code:
|
||||
pipeline_options.do_code_enrichment=True
|
||||
converter = DocumentConverter(format_options={
|
||||
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
|
||||
})
|
||||
result = converter.convert(pdf).document.export_to_markdown( image_mode=ImageRefMode.EMBEDDED)
|
||||
return result
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
200
filetranslate/utils/markdown_splitter.py
Normal file
200
filetranslate/utils/markdown_splitter.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
class MarkdownBlockSplitter:
|
||||
def __init__(self, max_block_size: int = 4096):
|
||||
"""
|
||||
初始化MarkdownBlockSplitter。
|
||||
|
||||
参数:
|
||||
max_block_size: 每个块的最大大小(以字符为单位)。
|
||||
"""
|
||||
self.max_block_size = max_block_size
|
||||
|
||||
def split_markdown(self, markdown_text: str) -> List[str]:
|
||||
"""
|
||||
将markdown文本拆分为指定最大大小的块。
|
||||
|
||||
参数:
|
||||
markdown_text: 输入的markdown文本。
|
||||
|
||||
返回:
|
||||
列表形式的markdown块,每个都是一个字符串。
|
||||
"""
|
||||
# 使用更简单的方法:按Markdown块拆分
|
||||
# 这比使用AST解析更可靠
|
||||
|
||||
# 模式用于识别markdown块(标题、段落、代码块等)
|
||||
blocks = self._split_into_logical_blocks(markdown_text)
|
||||
|
||||
# 现在合并块以遵守max_block_size
|
||||
result_blocks = []
|
||||
current_block = ""
|
||||
|
||||
for block in blocks:
|
||||
# 如果单个块大于最大大小,则进一步拆分
|
||||
if len(block) > self.max_block_size:
|
||||
# 如果已有累积内容,先添加
|
||||
if current_block:
|
||||
result_blocks.append(current_block)
|
||||
current_block = ""
|
||||
|
||||
# 拆分大块
|
||||
large_block_parts = self._split_large_block(block)
|
||||
result_blocks.extend(large_block_parts)
|
||||
continue
|
||||
|
||||
# 如果添加此块会超过限制,则开始新的结果块
|
||||
if len(current_block) + len(block) + 2 > self.max_block_size and current_block:
|
||||
result_blocks.append(current_block)
|
||||
current_block = block
|
||||
else:
|
||||
# 添加到当前块并适当换行
|
||||
if current_block:
|
||||
current_block += "\n\n" + block
|
||||
else:
|
||||
current_block = block
|
||||
|
||||
# 如果不为空则添加最后一个块
|
||||
if current_block:
|
||||
result_blocks.append(current_block)
|
||||
|
||||
return result_blocks
|
||||
|
||||
def _split_into_logical_blocks(self, markdown_text: str) -> List[str]:
|
||||
"""
|
||||
将markdown文本拆分为逻辑块(标题、段落、代码块等)
|
||||
|
||||
参数:
|
||||
markdown_text: 输入markdown文本
|
||||
|
||||
返回:
|
||||
markdown块列表
|
||||
"""
|
||||
# 将Windows换行符替换为Unix风格
|
||||
markdown_text = markdown_text.replace('\r\n', '\n')
|
||||
|
||||
# 匹配代码块的模式(用```或~~~围起来)
|
||||
code_block_pattern = r'(```[\s\S]*?```|~~~[\s\S]*?~~~)'
|
||||
|
||||
# 将文本拆分为代码块和非代码块
|
||||
parts = re.split(code_block_pattern, markdown_text)
|
||||
|
||||
blocks = []
|
||||
for i, part in enumerate(parts):
|
||||
# 如果是代码块(拆分结果中的奇数索引)
|
||||
if i % 2 == 1:
|
||||
blocks.append(part)
|
||||
else:
|
||||
# 对于非代码块,按空行拆分
|
||||
part_blocks = re.split(r'\n\s*\n', part)
|
||||
blocks.extend([b.strip() for b in part_blocks if b.strip()])
|
||||
|
||||
return blocks
|
||||
|
||||
def _split_large_block(self, block: str) -> List[str]:
|
||||
"""
|
||||
拆分超过max_block_size的大块。
|
||||
|
||||
参数:
|
||||
block: 一个大的markdown块
|
||||
|
||||
返回:
|
||||
较小的块列表
|
||||
"""
|
||||
result = []
|
||||
|
||||
# 检查是否是代码块
|
||||
if block.startswith('```') or block.startswith('~~~'):
|
||||
# 对于代码块,我们需要保留围栏标记
|
||||
fence_marker = '```' if block.startswith('```') else '~~~'
|
||||
|
||||
# 提取语言说明符(如果存在)
|
||||
first_line_end = block.find('\n')
|
||||
first_line = block[:first_line_end]
|
||||
language_spec = first_line[3:].strip()
|
||||
|
||||
# 拆分代码内容
|
||||
code_content = block[first_line_end + 1:-3].strip()
|
||||
|
||||
# 按行拆分
|
||||
lines = code_content.split('\n')
|
||||
|
||||
current_part = [first_line]
|
||||
current_size = len(first_line) + 1 # +1表示换行符
|
||||
|
||||
for line in lines:
|
||||
line_size = len(line) + 1 # +1表示换行符
|
||||
|
||||
if current_size + line_size + 3 > self.max_block_size: # +3表示关闭围栏
|
||||
# 关闭当前代码块
|
||||
current_part.append(fence_marker)
|
||||
result.append('\n'.join(current_part))
|
||||
|
||||
# 开始新的代码块
|
||||
current_part = [f"{fence_marker}{language_spec}"]
|
||||
current_size = len(current_part[0]) + 1
|
||||
|
||||
current_part.append(line)
|
||||
current_size += line_size
|
||||
|
||||
# 在最后部分添加关闭围栏
|
||||
current_part.append(fence_marker)
|
||||
result.append('\n'.join(current_part))
|
||||
|
||||
else:
|
||||
# 对于其他块,按句子或行拆分
|
||||
if '.' in block or '!' in block or '?' in block:
|
||||
# 按句子拆分
|
||||
sentences = re.split(r'(?<=[.!?])\s+', block)
|
||||
|
||||
current_part = []
|
||||
current_size = 0
|
||||
|
||||
for sentence in sentences:
|
||||
if current_size + len(sentence) + 1 > self.max_block_size and current_part:
|
||||
result.append(' '.join(current_part))
|
||||
current_part = [sentence]
|
||||
current_size = len(sentence)
|
||||
else:
|
||||
current_part.append(sentence)
|
||||
current_size += len(sentence) + 1 # +1表示空格
|
||||
|
||||
if current_part:
|
||||
result.append(' '.join(current_part))
|
||||
else:
|
||||
# 按行拆分
|
||||
lines = block.split('\n')
|
||||
|
||||
current_part = []
|
||||
current_size = 0
|
||||
|
||||
for line in lines:
|
||||
if current_size + len(line) + 1 > self.max_block_size and current_part:
|
||||
result.append('\n'.join(current_part))
|
||||
current_part = [line]
|
||||
current_size = len(line)
|
||||
else:
|
||||
current_part.append(line)
|
||||
current_size += len(line) + 1 # +1表示换行符
|
||||
|
||||
if current_part:
|
||||
result.append('\n'.join(current_part))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def split_markdown_text(markdown_text, max_block_size=4096):
|
||||
"""
|
||||
将markdown字符串拆分为不超过max_block_size的块。
|
||||
|
||||
参数:
|
||||
markdown_text: 输入markdown文本
|
||||
max_block_size: 每个块的最大字符数
|
||||
|
||||
返回:
|
||||
markdown块列表
|
||||
"""
|
||||
splitter = MarkdownBlockSplitter(max_block_size=max_block_size)
|
||||
return splitter.split_markdown(markdown_text)
|
||||
61
filetranslate/utils/markdown_utils.py
Normal file
61
filetranslate/utils/markdown_utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
|
||||
|
||||
class MaskDict:
|
||||
def __init__(self):
|
||||
self._dict = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def create_id(self):
|
||||
with self._lock:
|
||||
while True:
|
||||
id = uuid.uuid1().hex[:6]
|
||||
if id not in self._dict:
|
||||
return id
|
||||
|
||||
def get(self, key):
|
||||
with self._lock:
|
||||
return self._dict.get(key)
|
||||
|
||||
def set(self, key, value):
|
||||
with self._lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def delete(self, key):
|
||||
with self._lock:
|
||||
if key in self._dict:
|
||||
del self._dict[key]
|
||||
|
||||
def __contains__(self, item):
|
||||
with self._lock:
|
||||
return item in self._dict
|
||||
def uris2placeholder(markdown:str, mask_dict:MaskDict):
|
||||
def uri2placeholder(match: re.Match):
|
||||
id = mask_dict.create_id()
|
||||
mask_dict.set(id, match.group())
|
||||
return f"<ph-{id}>"
|
||||
|
||||
uri_pattern = r'!?\[.*?\]\(.*?\)'
|
||||
markdown = re.sub(uri_pattern, uri2placeholder, markdown)
|
||||
return markdown
|
||||
|
||||
def placeholder2_uris(markdown:str, mask_dict:MaskDict):
|
||||
def placeholder2uri(match:re.Match):
|
||||
id=match.group(1)
|
||||
uri=mask_dict.get(id)
|
||||
if uri is None:
|
||||
return match.group()
|
||||
return uri
|
||||
|
||||
ph_pattern = r"<ph-([a-zA-Z0-9]+)>"
|
||||
markdown = re.sub(ph_pattern, placeholder2uri, markdown)
|
||||
return markdown
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
Reference in New Issue
Block a user