增加provider

This commit is contained in:
xunbu
2025-12-27 23:22:52 +08:00
parent 70a444f2b7
commit 6f4e5195c5

View File

@@ -53,6 +53,7 @@ from pydantic import (
AliasChoices, AliasChoices,
ConfigDict, ConfigDict,
Json, Json,
TypeAdapter, # Added TypeAdapter
) )
from docutranslate import __version__ from docutranslate import __version__
@@ -176,6 +177,35 @@ def _create_default_task_state() -> Dict[str, Any]:
} }
def get_workflow_type_from_filename(filename: str) -> str:
"""根据文件扩展名自动选择 workflow_type"""
ext = Path(filename).suffix.lower()
if ext in [".pdf",".png",".jpg"]:
return "markdown_based"
elif ext in [".md", ".markdown"]:
return "markdown_based"
elif ext in [".docx",".doc"]:
return "docx"
elif ext in [".csv",".xlsx",".xls"]:
return "xlsx"
elif ext in [".pptx","ppt"]:
return "pptx"
elif ext in [".json"]:
return "json"
elif ext in [".srt"]:
return "srt"
elif ext in [".ass"]:
return "ass"
elif ext in [".epub"]:
return "epub"
elif ext in [".html", ".htm"]:
return "html"
elif ext in [".txt"]:
return "txt"
else:
return "txt"
# --- 日志处理器 --- # --- 日志处理器 ---
class QueueAndHistoryHandler(logging.Handler): class QueueAndHistoryHandler(logging.Handler):
def __init__( def __init__(
@@ -279,7 +309,6 @@ DocuTranslate 后端服务 API提供文档翻译、状态查询、结果下
**版本**: {__version__} **版本**: {__version__}
""", """,
version=__version__, version=__version__,
openapi_tags=tags_metadata,
) )
# mimetypes.add_type("application/wasm", ".wasm") # mimetypes.add_type("application/wasm", ".wasm")
service_router = APIRouter(prefix="/service", tags=["Service API"]) service_router = APIRouter(prefix="/service", tags=["Service API"])
@@ -440,10 +469,13 @@ class BaseWorkflowParams(BaseModel):
if not values.get("skip_translate"): if not values.get("skip_translate"):
# Check for standard keys or their aliases # Check for standard keys or their aliases
if not (values.get("base_url") or values.get("baseurl")): if not (values.get("base_url") or values.get("baseurl")):
# Auto 模式在校验前不强制要求 base_url
if values.get("workflow_type") != "auto":
raise ValueError( raise ValueError(
"当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。" "当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。"
) )
if not values.get("model_id"): if not values.get("model_id"):
if values.get("workflow_type") != "auto":
raise ValueError( raise ValueError(
"当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。" "当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。"
) )
@@ -452,6 +484,11 @@ class BaseWorkflowParams(BaseModel):
# 2. 为每个工作流创建独立的参数模型 # 2. 为每个工作流创建独立的参数模型
class AutoWorkflowParams(BaseWorkflowParams):
workflow_type: Literal["auto"] = Field(..., description="根据文件后缀自动选择工作流。")
model_config = ConfigDict(extra='allow')
class MarkdownWorkflowParams(BaseWorkflowParams): class MarkdownWorkflowParams(BaseWorkflowParams):
workflow_type: Literal["markdown_based"] = Field( workflow_type: Literal["markdown_based"] = Field(
..., description="指定使用基于Markdown的翻译工作流。" ..., description="指定使用基于Markdown的翻译工作流。"
@@ -622,7 +659,7 @@ class HtmlWorkflowParams(BaseWorkflowParams):
) )
insert_mode: Literal["replace", "append", "prepend"] = Field( insert_mode: Literal["replace", "append", "prepend"] = Field(
"replace", "replace",
description="翻译文本的插入模式。'replace':替换原文,'append'附加到原文后,'prepend':附加到原文前。", description="翻译文本的插入模式。'replace':替换原文,'append' :附加到原文后,'prepend':附加到原文前。",
) )
separator: str = Field( separator: str = Field(
" ", " ",
@@ -673,6 +710,7 @@ class PPTXWorkflowParams(BaseWorkflowParams):
# 3. 使用可辨识联合类型Discriminated Union将它们组合起来 # 3. 使用可辨识联合类型Discriminated Union将它们组合起来
TranslatePayload = Annotated[ TranslatePayload = Annotated[
Union[ Union[
AutoWorkflowParams,
MarkdownWorkflowParams, MarkdownWorkflowParams,
TextWorkflowParams, TextWorkflowParams,
JsonWorkflowParams, JsonWorkflowParams,
@@ -714,6 +752,17 @@ class TranslateServiceRequest(BaseModel):
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
"examples": [ "examples": [
{
"file_name": "auto_detect_doc.pdf",
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
"payload": {
"workflow_type": "auto",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-your-api-key-here",
"model_id": "gpt-4o",
"to_lang": "中文",
},
},
{ {
"file_name": "annual_report_203.pdf", "file_name": "annual_report_203.pdf",
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...", "file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
@@ -1462,7 +1511,6 @@ async def _perform_translation(
# 定义导出函数映射 # 定义导出函数映射
export_map = {} export_map = {}
if isinstance(workflow, MDFormatsExportable): if isinstance(workflow, MDFormatsExportable):
export_map["markdown"] = ( export_map["markdown"] = (
workflow.export_to_markdown, workflow.export_to_markdown,
@@ -1665,6 +1713,32 @@ async def _start_translation_task(
file_contents: bytes, file_contents: bytes,
original_filename: str, original_filename: str,
): ):
# --- 新增: Auto 工作流路由逻辑 ---
if payload.workflow_type == "auto":
detected_type = get_workflow_type_from_filename(original_filename)
print(f"[{task_id}] 自动识别工作流: {original_filename} -> {detected_type}")
# 将参数转换为目标具体工作流类型所需的字典
payload_data = payload.model_dump()
payload_data["workflow_type"] = detected_type
# 针对特定格式的默认策略
if detected_type == "json" and not payload_data.get("json_paths"):
payload_data["json_paths"] = ["$..*"] # 默认翻译所有内容
if detected_type == "markdown_based" and not payload_data.get("convert_engine"):
if Path(original_filename).suffix.lower() == ".pdf":
payload_data["convert_engine"] = "mineru" if not DOCLING_EXIST else "docling"
else:
payload_data["convert_engine"] = "identity"
# 重新校验为具体的 Payload 类型
try:
payload = TypeAdapter(TranslatePayload).validate_python(payload_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"自动转换工作流参数失败: {e}")
# -----------------------------
if task_id not in tasks_state: if task_id not in tasks_state:
tasks_state[task_id] = _create_default_task_state() tasks_state[task_id] = _create_default_task_state()
tasks_log_queues[task_id] = asyncio.Queue() tasks_log_queues[task_id] = asyncio.Queue()
@@ -1778,7 +1852,8 @@ def _cancel_translation_logic(task_id: str):
description=""" description="""
接收一个包含文件内容Base64编码和工作流参数的JSON请求启动一个后台翻译任务。 接收一个包含文件内容Base64编码和工作流参数的JSON请求启动一个后台翻译任务。
- **工作流选择**: 请求体中的 `payload.workflow_type` 字段决定了本次任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`)。 - **工作流选择**: `payload.workflow_type` 决定任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`, `auto`)。
- **Auto 模式**: 当设置为 `auto` 时,后端将根据 `file_name` 的扩展名自动选择最合适的工作流。
- **动态参数**: 根据所选工作流API需要不同的参数集。请参考下面的Schema或示例。 - **动态参数**: 根据所选工作流API需要不同的参数集。请参考下面的Schema或示例。
- **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度。 - **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度。
""", """,
@@ -1840,12 +1915,29 @@ async def service_translate(
"/translate/file", "/translate/file",
summary="提交翻译任务 (文件上传)", summary="提交翻译任务 (文件上传)",
description=""" description="""
接收一个上传文件和包含工作流参数的JSON字符串启动一个后台翻译任务。 通过 `multipart/form-data` 方式上传文件并启动翻译任务。
- **工作流选择**: `payload` 表单字段中的 `workflow_type` 字段决定了本次任务的类型 此接口适用于直接上传二进制文件(如 PDF, Docx 等),无需先进行 Base64 编码
- **文件上传**: 通过 `file` 字段上传文件替代JSON接口中的 `file_content` 和 `file_name`。
- **参数传递**: `payload` 字段应为一个符合 JSON 格式的字符串,其结构与 `/service/translate` 中的 `payload` 字段完全一致。 ### 参数说明
- **异步处理**: 此端点会立即返回任务ID客户端需轮询状态接口获取进度 - **file**: (必须) 要翻译的二进制文件
- **payload**: (必须) 包含工作流配置的 **JSON 字符串**。
- 必须包含 `workflow_type` (如 `auto`, `docx`, `markdown_based` 等)。
- 其他参数根据 `workflow_type` 不同而变化 (详见 `TranslatePayload` 模型)。
### Payload 示例 (JSON String)
```json
{
"workflow_type": "auto",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-xxxxxx",
"model_id": "gpt-4o",
"to_lang": "中文"
}
```
### 响应
返回包含 `task_id` 的 JSON 对象。客户端需使用此 ID 轮询 `/service/status/{task_id}` 接口获取进度。
""", """,
responses={ responses={
200: { 200: {
@@ -1870,7 +1962,7 @@ async def service_translate(
async def service_translate_file( async def service_translate_file(
file: UploadFile = File(..., description="要翻译的文件"), file: UploadFile = File(..., description="要翻译的文件"),
payload: Json[TranslatePayload] = Form( payload: Json[TranslatePayload] = Form(
..., description="包含工作流参数的JSON字符串结构与JSON接口的payload一致" ..., description="包含工作流参数的JSON字符串 (详见接口文档说明)"
), ),
): ):
task_id = uuid.uuid4().hex[:8] task_id = uuid.uuid4().hex[:8]
@@ -2594,7 +2686,8 @@ def find_free_port(start_port):
port += 1 port += 1
def run_app(host=None,port: int | None = None,enable_CORS=False,allow_origin_regex=r"^(https?://.*|null|file://.*)$"): def run_app(host=None, port: int | None = None, enable_CORS=False,
allow_origin_regex=r"^(https?://.*|null|file://.*)$"):
initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010)) initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010))
try: try:
port_to_use = find_free_port(initial_port) port_to_use = find_free_port(initial_port)