增加provider

This commit is contained in:
xunbu
2025-12-27 23:22:52 +08:00
parent 70a444f2b7
commit 6f4e5195c5

View File

@@ -53,6 +53,7 @@ from pydantic import (
AliasChoices, AliasChoices,
ConfigDict, ConfigDict,
Json, Json,
TypeAdapter, # Added TypeAdapter
) )
from docutranslate import __version__ from docutranslate import __version__
@@ -176,14 +177,43 @@ def _create_default_task_state() -> Dict[str, Any]:
} }
# Maps a lowercase file suffix (leading dot included, exactly as returned by
# ``Path.suffix``) to the workflow type that should process the file.
_WORKFLOW_TYPE_BY_SUFFIX = {
    # PDFs and images go through the markdown-based pipeline.
    ".pdf": "markdown_based",
    ".png": "markdown_based",
    ".jpg": "markdown_based",
    ".md": "markdown_based",
    ".markdown": "markdown_based",
    ".docx": "docx",
    ".doc": "docx",
    ".csv": "xlsx",
    ".xlsx": "xlsx",
    ".xls": "xlsx",
    ".pptx": "pptx",
    # The original list contained "ppt" without the dot, which could never
    # equal a value produced by ``Path.suffix``.
    ".ppt": "pptx",
    ".json": "json",
    ".srt": "srt",
    ".ass": "ass",
    ".epub": "epub",
    ".html": "html",
    ".htm": "html",
    ".txt": "txt",
}


def get_workflow_type_from_filename(filename: str) -> str:
    """Pick a ``workflow_type`` automatically from the file extension.

    Args:
        filename: File name or path; only its final suffix is inspected and
            the comparison is case-insensitive.

    Returns:
        The workflow type registered for the suffix, or ``"txt"`` when the
        suffix is missing or not recognised.
    """
    suffix = Path(filename).suffix.lower()
    # Unknown or absent extensions fall back to the plain-text workflow.
    return _WORKFLOW_TYPE_BY_SUFFIX.get(suffix, "txt")
# --- 日志处理器 --- # --- 日志处理器 ---
class QueueAndHistoryHandler(logging.Handler): class QueueAndHistoryHandler(logging.Handler):
def __init__( def __init__(
self, self,
queue_ref: asyncio.Queue, queue_ref: asyncio.Queue,
history_list_ref: List[str], history_list_ref: List[str],
max_history_items: int, max_history_items: int,
task_id: str, task_id: str,
): ):
super().__init__() super().__init__()
self.queue = queue_ref self.queue = queue_ref
@@ -279,7 +309,6 @@ DocuTranslate 后端服务 API提供文档翻译、状态查询、结果下
**版本**: {__version__} **版本**: {__version__}
""", """,
version=__version__, version=__version__,
openapi_tags=tags_metadata,
) )
# mimetypes.add_type("application/wasm", ".wasm") # mimetypes.add_type("application/wasm", ".wasm")
service_router = APIRouter(prefix="/service", tags=["Service API"]) service_router = APIRouter(prefix="/service", tags=["Service API"])
@@ -440,18 +469,26 @@ class BaseWorkflowParams(BaseModel):
if not values.get("skip_translate"): if not values.get("skip_translate"):
# Check for standard keys or their aliases # Check for standard keys or their aliases
if not (values.get("base_url") or values.get("baseurl")): if not (values.get("base_url") or values.get("baseurl")):
raise ValueError( # Auto 模式在校验前不强制要求 base_url
"当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。" if values.get("workflow_type") != "auto":
) raise ValueError(
"当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。"
)
if not values.get("model_id"): if not values.get("model_id"):
raise ValueError( if values.get("workflow_type") != "auto":
"当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。" raise ValueError(
) "当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。"
)
# 如果跳过翻译,则不进行任何检查,允许 base_url 等字段为空 # 如果跳过翻译,则不进行任何检查,允许 base_url 等字段为空
return values return values
# 2. 为每个工作流创建独立的参数模型 # 2. 为每个工作流创建独立的参数模型
class AutoWorkflowParams(BaseWorkflowParams):
    # Payload variant whose workflow_type is the literal "auto": the concrete
    # workflow is chosen later from the uploaded file's extension.
    # NOTE: documented with comments rather than a docstring on purpose -
    # pydantic uses a model's docstring as its schema description.
    workflow_type: Literal["auto"] = Field(..., description="根据文件后缀自动选择工作流。")
    # extra='allow' keeps fields this model does not declare instead of
    # rejecting them; presumably so workflow-specific options survive until the
    # payload is re-validated as a concrete workflow model - confirm against
    # the task-start logic.
    model_config = ConfigDict(extra='allow')
class MarkdownWorkflowParams(BaseWorkflowParams): class MarkdownWorkflowParams(BaseWorkflowParams):
workflow_type: Literal["markdown_based"] = Field( workflow_type: Literal["markdown_based"] = Field(
..., description="指定使用基于Markdown的翻译工作流。" ..., description="指定使用基于Markdown的翻译工作流。"
@@ -622,7 +659,7 @@ class HtmlWorkflowParams(BaseWorkflowParams):
) )
insert_mode: Literal["replace", "append", "prepend"] = Field( insert_mode: Literal["replace", "append", "prepend"] = Field(
"replace", "replace",
description="翻译文本的插入模式。'replace':替换原文,'append'附加到原文后,'prepend':附加到原文前。", description="翻译文本的插入模式。'replace':替换原文,'append' :附加到原文后,'prepend':附加到原文前。",
) )
separator: str = Field( separator: str = Field(
" ", " ",
@@ -673,6 +710,7 @@ class PPTXWorkflowParams(BaseWorkflowParams):
# 3. 使用可辨识联合类型Discriminated Union将它们组合起来 # 3. 使用可辨识联合类型Discriminated Union将它们组合起来
TranslatePayload = Annotated[ TranslatePayload = Annotated[
Union[ Union[
AutoWorkflowParams,
MarkdownWorkflowParams, MarkdownWorkflowParams,
TextWorkflowParams, TextWorkflowParams,
JsonWorkflowParams, JsonWorkflowParams,
@@ -714,6 +752,17 @@ class TranslateServiceRequest(BaseModel):
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
"examples": [ "examples": [
{
"file_name": "auto_detect_doc.pdf",
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
"payload": {
"workflow_type": "auto",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-your-api-key-here",
"model_id": "gpt-4o",
"to_lang": "中文",
},
},
{ {
"file_name": "annual_report_203.pdf", "file_name": "annual_report_203.pdf",
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...", "file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
@@ -956,10 +1005,10 @@ class TranslateServiceRequest(BaseModel):
# --- Background Task Logic --- # --- Background Task Logic ---
async def _perform_translation( async def _perform_translation(
task_id: str, task_id: str,
payload: TranslatePayload, payload: TranslatePayload,
file_contents: bytes, file_contents: bytes,
original_filename: str, original_filename: str,
): ):
task_state = tasks_state[task_id] task_state = tasks_state[task_id]
log_queue = tasks_log_queues[task_id] log_queue = tasks_log_queues[task_id]
@@ -1462,7 +1511,6 @@ async def _perform_translation(
# 定义导出函数映射 # 定义导出函数映射
export_map = {} export_map = {}
if isinstance(workflow, MDFormatsExportable): if isinstance(workflow, MDFormatsExportable):
export_map["markdown"] = ( export_map["markdown"] = (
workflow.export_to_markdown, workflow.export_to_markdown,
@@ -1660,11 +1708,37 @@ async def _perform_translation(
# --- 核心任务启动逻辑 --- # --- 核心任务启动逻辑 ---
async def _start_translation_task( async def _start_translation_task(
task_id: str, task_id: str,
payload: TranslatePayload, payload: TranslatePayload,
file_contents: bytes, file_contents: bytes,
original_filename: str, original_filename: str,
): ):
# --- 新增: Auto 工作流路由逻辑 ---
if payload.workflow_type == "auto":
detected_type = get_workflow_type_from_filename(original_filename)
print(f"[{task_id}] 自动识别工作流: {original_filename} -> {detected_type}")
# 将参数转换为目标具体工作流类型所需的字典
payload_data = payload.model_dump()
payload_data["workflow_type"] = detected_type
# 针对特定格式的默认策略
if detected_type == "json" and not payload_data.get("json_paths"):
payload_data["json_paths"] = ["$..*"] # 默认翻译所有内容
if detected_type == "markdown_based" and not payload_data.get("convert_engine"):
if Path(original_filename).suffix.lower() == ".pdf":
payload_data["convert_engine"] = "mineru" if not DOCLING_EXIST else "docling"
else:
payload_data["convert_engine"] = "identity"
# 重新校验为具体的 Payload 类型
try:
payload = TypeAdapter(TranslatePayload).validate_python(payload_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"自动转换工作流参数失败: {e}")
# -----------------------------
if task_id not in tasks_state: if task_id not in tasks_state:
tasks_state[task_id] = _create_default_task_state() tasks_state[task_id] = _create_default_task_state()
tasks_log_queues[task_id] = asyncio.Queue() tasks_log_queues[task_id] = asyncio.Queue()
@@ -1672,9 +1746,9 @@ async def _start_translation_task(
task_state = tasks_state[task_id] task_state = tasks_state[task_id]
if ( if (
task_state["is_processing"] task_state["is_processing"]
and task_state["current_task_ref"] and task_state["current_task_ref"]
and not task_state["current_task_ref"].done() and not task_state["current_task_ref"].done()
): ):
raise HTTPException( raise HTTPException(
status_code=429, detail=f"任务ID '{task_id}' 正在进行中,请稍后再试。" status_code=429, detail=f"任务ID '{task_id}' 正在进行中,请稍后再试。"
@@ -1694,7 +1768,7 @@ async def _start_translation_task(
"error_flag": False, "error_flag": False,
"download_ready": False, "download_ready": False,
"workflow_instance": None, "workflow_instance": None,
"original_filename_stem": safe_stem, # 存入安全的stem "original_filename_stem": safe_stem, # 存入安全的stem
"original_filename": original_filename, "original_filename": original_filename,
"task_start_time": time.time(), "task_start_time": time.time(),
"task_end_time": 0, "task_end_time": 0,
@@ -1747,9 +1821,9 @@ def _cancel_translation_logic(task_id: str):
if not task_state: if not task_state:
raise HTTPException(status_code=404, detail=f"找不到任务ID '{task_id}'") raise HTTPException(status_code=404, detail=f"找不到任务ID '{task_id}'")
if ( if (
not task_state not task_state
or not task_state["is_processing"] or not task_state["is_processing"]
or not task_state["current_task_ref"] or not task_state["current_task_ref"]
): ):
raise HTTPException( raise HTTPException(
status_code=400, detail=f"任务ID '{task_id}' 没有正在进行的翻译任务可取消。" status_code=400, detail=f"任务ID '{task_id}' 没有正在进行的翻译任务可取消。"
@@ -1778,7 +1852,8 @@ def _cancel_translation_logic(task_id: str):
description=""" description="""
接收一个包含文件内容Base64编码和工作流参数的JSON请求启动一个后台翻译任务。 接收一个包含文件内容Base64编码和工作流参数的JSON请求启动一个后台翻译任务。
- **工作流选择**: 请求体中的 `payload.workflow_type` 字段决定了本次任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`)。 - **工作流选择**: `payload.workflow_type` 决定任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`, `auto`)。
- **Auto 模式**: 当设置为 `auto` 时,后端将根据 `file_name` 的扩展名自动选择最合适的工作流。
- **动态参数**: 根据所选工作流API需要不同的参数集。请参考下面的Schema或示例。 - **动态参数**: 根据所选工作流API需要不同的参数集。请参考下面的Schema或示例。
- **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度。 - **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度。
""", """,
@@ -1803,9 +1878,9 @@ def _cancel_translation_logic(task_id: str):
}, },
) )
async def service_translate( async def service_translate(
request: TranslateServiceRequest = Body( request: TranslateServiceRequest = Body(
..., description="翻译任务的详细参数和文件内容。" ..., description="翻译任务的详细参数和文件内容。"
) )
): ):
task_id = uuid.uuid4().hex[:8] task_id = uuid.uuid4().hex[:8]
@@ -1840,13 +1915,30 @@ async def service_translate(
"/translate/file", "/translate/file",
summary="提交翻译任务 (文件上传)", summary="提交翻译任务 (文件上传)",
description=""" description="""
接收一个上传文件和包含工作流参数的JSON字符串启动一个后台翻译任务。 通过 `multipart/form-data` 方式上传文件并启动翻译任务。
- **工作流选择**: `payload` 表单字段中的 `workflow_type` 字段决定了本次任务的类型 此接口适用于直接上传二进制文件(如 PDF, Docx 等),无需先进行 Base64 编码
- **文件上传**: 通过 `file` 字段上传文件替代JSON接口中的 `file_content` 和 `file_name`。
- **参数传递**: `payload` 字段应为一个符合 JSON 格式的字符串,其结构与 `/service/translate` 中的 `payload` 字段完全一致。 ### 参数说明
- **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度 - **file**: (必须) 要翻译的二进制文件
""", - **payload**: (必须) 包含工作流配置的 **JSON 字符串**。
- 必须包含 `workflow_type` (如 `auto`, `docx`, `markdown_based` 等)。
- 其他参数根据 `workflow_type` 不同而变化 (详见 `TranslatePayload` 模型)。
### Payload 示例 (JSON String)
```json
{
"workflow_type": "auto",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-xxxxxx",
"model_id": "gpt-4o",
"to_lang": "中文"
}
```
### 响应
返回包含 `task_id` 的 JSON 对象。客户端需使用此 ID 轮询 `/service/status/{task_id}` 接口获取进度。
""",
responses={ responses={
200: { 200: {
"description": "翻译任务已成功启动。", "description": "翻译任务已成功启动。",
@@ -1868,10 +1960,10 @@ async def service_translate(
}, },
) )
async def service_translate_file( async def service_translate_file(
file: UploadFile = File(..., description="要翻译的文件"), file: UploadFile = File(..., description="要翻译的文件"),
payload: Json[TranslatePayload] = Form( payload: Json[TranslatePayload] = Form(
..., description="包含工作流参数的JSON字符串结构与JSON接口的payload一致" ..., description="包含工作流参数的JSON字符串 (详见接口文档说明)"
), ),
): ):
task_id = uuid.uuid4().hex[:8] task_id = uuid.uuid4().hex[:8]
@@ -1925,9 +2017,9 @@ async def service_release_task(task_id: str):
task_state = tasks_state.get(task_id) task_state = tasks_state.get(task_id)
message_parts = [] message_parts = []
if ( if (
task_state task_state
and task_state.get("is_processing") and task_state.get("is_processing")
and task_state.get("current_task_ref") and task_state.get("current_task_ref")
): ):
try: try:
print(f"[{task_id}] 任务正在进行中,将在释放前尝试取消。") print(f"[{task_id}] 任务正在进行中,将在释放前尝试取消。")
@@ -2167,9 +2259,9 @@ async def service_release_task(task_id: str):
}, },
) )
async def service_get_status( async def service_get_status(
task_id: str = FastApiPath( task_id: str = FastApiPath(
..., description="要查询状态的任务的ID", examples=["b2865b93"] ..., description="要查询状态的任务的ID", examples=["b2865b93"]
) )
): ):
task_state = tasks_state.get(task_id) task_state = tasks_state.get(task_id)
if not task_state: if not task_state:
@@ -2273,14 +2365,14 @@ FileType = Literal[
}, },
) )
async def service_download_file( async def service_download_file(
task_id: str = FastApiPath( task_id: str = FastApiPath(
..., description="已完成任务的ID", examples=["b2865b93"] ..., description="已完成任务的ID", examples=["b2865b93"]
), ),
file_type: FileType = FastApiPath( file_type: FileType = FastApiPath(
..., ...,
description="要下载的文件类型。", description="要下载的文件类型。",
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
), ),
): ):
task_state = tasks_state.get(task_id) task_state = tasks_state.get(task_id)
if not task_state: if not task_state:
@@ -2319,12 +2411,12 @@ async def service_download_file(
}, },
) )
async def service_download_attachment( async def service_download_attachment(
task_id: str = FastApiPath( task_id: str = FastApiPath(
..., description="已完成任务的ID", examples=["g1h2i3j4"] ..., description="已完成任务的ID", examples=["g1h2i3j4"]
), ),
identifier: str = FastApiPath( identifier: str = FastApiPath(
..., description="要下载的附件的标识符。", examples=["glossary"] ..., description="要下载的附件的标识符。", examples=["glossary"]
), ),
): ):
task_state = tasks_state.get(task_id) task_state = tasks_state.get(task_id)
if not task_state: if not task_state:
@@ -2404,14 +2496,14 @@ async def service_download_attachment(
}, },
) )
async def service_content( async def service_content(
task_id: str = FastApiPath( task_id: str = FastApiPath(
..., description="已完成任务的ID", examples=["b2865b93"] ..., description="已完成任务的ID", examples=["b2865b93"]
), ),
file_type: FileType = FastApiPath( file_type: FileType = FastApiPath(
..., ...,
description="要获取内容的文件类型。", description="要获取内容的文件类型。",
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
), ),
): ):
task_state = tasks_state.get(task_id) task_state = tasks_state.get(task_id)
if not task_state: if not task_state:
@@ -2527,22 +2619,22 @@ async def redoc_html():
@app.post("/temp/translate", tags=["Temp"]) @app.post("/temp/translate", tags=["Temp"])
async def temp_translate( async def temp_translate(
base_url: str = Body(...), base_url: str = Body(...),
api_key: str = Body("xx"), api_key: str = Body("xx"),
model_id: str = Body(...), model_id: str = Body(...),
mineru_token: Optional[str] = Body(None), mineru_token: Optional[str] = Body(None),
file_name: str = Body(...), file_name: str = Body(...),
file_content: str = Body(...), file_content: str = Body(...),
to_lang: str = Body("中文"), to_lang: str = Body("中文"),
concurrent: int = Body(default_params["concurrent"]), concurrent: int = Body(default_params["concurrent"]),
temperature: float = Body(default_params["temperature"]), temperature: float = Body(default_params["temperature"]),
thinking: ThinkingMode = Body(default_params["thinking"]), thinking: ThinkingMode = Body(default_params["thinking"]),
chunk_size: int = Body(default_params["chunk_size"]), chunk_size: int = Body(default_params["chunk_size"]),
custom_prompt: Optional[str] = Body(None), custom_prompt: Optional[str] = Body(None),
model_version: Literal["pipeline", "vlm"] = Body("vlm"), model_version: Literal["pipeline", "vlm"] = Body("vlm"),
glossary_dict: Optional[Dict[str, str]] = Body(None), glossary_dict: Optional[Dict[str, str]] = Body(None),
rpm: Optional[int] = Body(None), rpm: Optional[int] = Body(None),
tpm: Optional[int] = Body(None), tpm: Optional[int] = Body(None),
): ):
file_name = Path(file_name) file_name = Path(file_name)
try: try:
@@ -2594,7 +2686,8 @@ def find_free_port(start_port):
port += 1 port += 1
def run_app(host=None,port: int | None = None,enable_CORS=False,allow_origin_regex=r"^(https?://.*|null|file://.*)$"): def run_app(host=None, port: int | None = None, enable_CORS=False,
allow_origin_regex=r"^(https?://.*|null|file://.*)$"):
initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010)) initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010))
try: try:
port_to_use = find_free_port(initial_port) port_to_use = find_free_port(initial_port)