# Lower-cased file extension (leading dot included, exactly as returned by
# ``Path.suffix``) -> workflow_type used to translate that kind of file.
_WORKFLOW_TYPE_BY_EXTENSION = {
    ".pdf": "markdown_based",
    ".png": "markdown_based",
    ".jpg": "markdown_based",
    ".md": "markdown_based",
    ".markdown": "markdown_based",
    ".docx": "docx",
    ".doc": "docx",
    ".csv": "xlsx",
    ".xlsx": "xlsx",
    ".xls": "xlsx",
    ".pptx": "pptx",
    # The original list contained "ppt" without the dot, which can never
    # equal a ``Path.suffix`` value, so *.ppt files silently became "txt".
    ".ppt": "pptx",
    ".json": "json",
    ".srt": "srt",
    ".ass": "ass",
    ".epub": "epub",
    ".html": "html",
    ".htm": "html",
    ".txt": "txt",
}

# Workflow used when the extension is missing or not listed above.
_FALLBACK_WORKFLOW_TYPE = "txt"


def get_workflow_type_from_filename(filename: str) -> str:
    """Pick the workflow_type that matches a file name's extension.

    Only the last suffix is considered and the comparison is
    case-insensitive (``"A.PDF"`` and ``"a.pdf"`` behave the same).

    Args:
        filename: File name or path whose extension selects the workflow.

    Returns:
        The workflow_type string for the extension, or ``"txt"`` when the
        name is empty, has no extension, or the extension is not recognised.
    """
    if not filename:
        # Nothing to inspect; avoid ``Path(None)`` raising TypeError.
        return _FALLBACK_WORKFLOW_TYPE
    ext = Path(filename).suffix.lower()
    return _WORKFLOW_TYPE_BY_EXTENSION.get(ext, _FALLBACK_WORKFLOW_TYPE)
class AutoWorkflowParams(BaseWorkflowParams):
    # Payload variant selected by ``workflow_type == "auto"``: the concrete
    # workflow is chosen later from the uploaded file's extension.
    #
    # Unknown keys are kept (``extra="allow"``) so that workflow-specific
    # options sent with an "auto" request are still present in
    # ``model_dump()`` when the payload is re-validated as a concrete
    # workflow payload.
    model_config = ConfigDict(extra="allow")

    # Discriminator value of this member of the ``TranslatePayload`` union.
    workflow_type: Literal["auto"] = Field(
        ...,
        description="根据文件后缀自动选择工作流。",
    )
使用可辨识联合类型(Discriminated Union)将它们组合起来 TranslatePayload = Annotated[ Union[ + AutoWorkflowParams, MarkdownWorkflowParams, TextWorkflowParams, JsonWorkflowParams, @@ -714,6 +752,17 @@ class TranslateServiceRequest(BaseModel): model_config = ConfigDict( json_schema_extra={ "examples": [ + { + "file_name": "auto_detect_doc.pdf", + "file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...", + "payload": { + "workflow_type": "auto", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-your-api-key-here", + "model_id": "gpt-4o", + "to_lang": "中文", + }, + }, { "file_name": "annual_report_203.pdf", "file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...", @@ -956,10 +1005,10 @@ class TranslateServiceRequest(BaseModel): # --- Background Task Logic --- async def _perform_translation( - task_id: str, - payload: TranslatePayload, - file_contents: bytes, - original_filename: str, + task_id: str, + payload: TranslatePayload, + file_contents: bytes, + original_filename: str, ): task_state = tasks_state[task_id] log_queue = tasks_log_queues[task_id] @@ -1462,7 +1511,6 @@ async def _perform_translation( # 定义导出函数映射 export_map = {} - if isinstance(workflow, MDFormatsExportable): export_map["markdown"] = ( workflow.export_to_markdown, @@ -1660,11 +1708,37 @@ async def _perform_translation( # --- 核心任务启动逻辑 --- async def _start_translation_task( - task_id: str, - payload: TranslatePayload, - file_contents: bytes, - original_filename: str, + task_id: str, + payload: TranslatePayload, + file_contents: bytes, + original_filename: str, ): + # --- 新增: Auto 工作流路由逻辑 --- + if payload.workflow_type == "auto": + detected_type = get_workflow_type_from_filename(original_filename) + print(f"[{task_id}] 自动识别工作流: {original_filename} -> {detected_type}") + + # 将参数转换为目标具体工作流类型所需的字典 + payload_data = payload.model_dump() + payload_data["workflow_type"] = detected_type + + # 针对特定格式的默认策略 + if detected_type == "json" and not payload_data.get("json_paths"): + payload_data["json_paths"] = ["$..*"] # 
默认翻译所有内容 + + if detected_type == "markdown_based" and not payload_data.get("convert_engine"): + if Path(original_filename).suffix.lower() == ".pdf": + payload_data["convert_engine"] = "mineru" if not DOCLING_EXIST else "docling" + else: + payload_data["convert_engine"] = "identity" + + # 重新校验为具体的 Payload 类型 + try: + payload = TypeAdapter(TranslatePayload).validate_python(payload_data) + except Exception as e: + raise HTTPException(status_code=400, detail=f"自动转换工作流参数失败: {e}") + # ----------------------------- + if task_id not in tasks_state: tasks_state[task_id] = _create_default_task_state() tasks_log_queues[task_id] = asyncio.Queue() @@ -1672,9 +1746,9 @@ async def _start_translation_task( task_state = tasks_state[task_id] if ( - task_state["is_processing"] - and task_state["current_task_ref"] - and not task_state["current_task_ref"].done() + task_state["is_processing"] + and task_state["current_task_ref"] + and not task_state["current_task_ref"].done() ): raise HTTPException( status_code=429, detail=f"任务ID '{task_id}' 正在进行中,请稍后再试。" @@ -1694,7 +1768,7 @@ async def _start_translation_task( "error_flag": False, "download_ready": False, "workflow_instance": None, - "original_filename_stem": safe_stem, # 存入安全的stem + "original_filename_stem": safe_stem, # 存入安全的stem "original_filename": original_filename, "task_start_time": time.time(), "task_end_time": 0, @@ -1747,9 +1821,9 @@ def _cancel_translation_logic(task_id: str): if not task_state: raise HTTPException(status_code=404, detail=f"找不到任务ID '{task_id}'。") if ( - not task_state - or not task_state["is_processing"] - or not task_state["current_task_ref"] + not task_state + or not task_state["is_processing"] + or not task_state["current_task_ref"] ): raise HTTPException( status_code=400, detail=f"任务ID '{task_id}' 没有正在进行的翻译任务可取消。" @@ -1778,7 +1852,8 @@ def _cancel_translation_logic(task_id: str): description=""" 接收一个包含文件内容(Base64编码)和工作流参数的JSON请求,启动一个后台翻译任务。 -- **工作流选择**: 请求体中的 `payload.workflow_type` 字段决定了本次任务的类型(如 
`markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`)。 +- **工作流选择**: `payload.workflow_type` 决定任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`, `auto`)。 +- **Auto 模式**: 当设置为 `auto` 时,后端将根据 `file_name` 的扩展名自动选择最合适的工作流。 - **动态参数**: 根据所选工作流,API需要不同的参数集。请参考下面的Schema或示例。 - **异步处理**: 此端点会立即返回任务ID,客户端需轮询状态接口获取进度。 """, @@ -1803,9 +1878,9 @@ def _cancel_translation_logic(task_id: str): }, ) async def service_translate( - request: TranslateServiceRequest = Body( - ..., description="翻译任务的详细参数和文件内容。" - ) + request: TranslateServiceRequest = Body( + ..., description="翻译任务的详细参数和文件内容。" + ) ): task_id = uuid.uuid4().hex[:8] @@ -1840,13 +1915,30 @@ async def service_translate( "/translate/file", summary="提交翻译任务 (文件上传)", description=""" -接收一个上传的文件和包含工作流参数的JSON字符串,启动一个后台翻译任务。 + 通过 `multipart/form-data` 方式上传文件并启动翻译任务。 -- **工作流选择**: `payload` 表单字段中的 `workflow_type` 字段决定了本次任务的类型。 -- **文件上传**: 通过 `file` 字段上传文件,替代JSON接口中的 `file_content` 和 `file_name`。 -- **参数传递**: `payload` 字段应为一个符合 JSON 格式的字符串,其结构与 `/service/translate` 中的 `payload` 字段完全一致。 -- **异步处理**: 此端点会立即返回任务ID,客户端需轮询状态接口获取进度。 -""", + 此接口适用于直接上传二进制文件(如 PDF, Docx 等),无需先进行 Base64 编码。 + + ### 参数说明 + - **file**: (必须) 要翻译的二进制文件。 + - **payload**: (必须) 包含工作流配置的 **JSON 字符串**。 + - 必须包含 `workflow_type` (如 `auto`, `docx`, `markdown_based` 等)。 + - 其他参数根据 `workflow_type` 不同而变化 (详见 `TranslatePayload` 模型)。 + + ### Payload 示例 (JSON String) + ```json + { + "workflow_type": "auto", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-xxxxxx", + "model_id": "gpt-4o", + "to_lang": "中文" + } + ``` + + ### 响应 + 返回包含 `task_id` 的 JSON 对象。客户端需使用此 ID 轮询 `/service/status/{task_id}` 接口获取进度。 + """, responses={ 200: { "description": "翻译任务已成功启动。", @@ -1868,10 +1960,10 @@ async def service_translate( }, ) async def service_translate_file( - file: UploadFile = File(..., description="要翻译的文件"), - payload: Json[TranslatePayload] = Form( - ..., 
description="包含工作流参数的JSON字符串,结构与JSON接口的payload一致。" - ), + file: UploadFile = File(..., description="要翻译的文件"), + payload: Json[TranslatePayload] = Form( + ..., description="包含工作流参数的JSON字符串 (详见接口文档说明)。" + ), ): task_id = uuid.uuid4().hex[:8] @@ -1925,9 +2017,9 @@ async def service_release_task(task_id: str): task_state = tasks_state.get(task_id) message_parts = [] if ( - task_state - and task_state.get("is_processing") - and task_state.get("current_task_ref") + task_state + and task_state.get("is_processing") + and task_state.get("current_task_ref") ): try: print(f"[{task_id}] 任务正在进行中,将在释放前尝试取消。") @@ -2167,9 +2259,9 @@ async def service_release_task(task_id: str): }, ) async def service_get_status( - task_id: str = FastApiPath( - ..., description="要查询状态的任务的ID", examples=["b2865b93"] - ) + task_id: str = FastApiPath( + ..., description="要查询状态的任务的ID", examples=["b2865b93"] + ) ): task_state = tasks_state.get(task_id) if not task_state: @@ -2273,14 +2365,14 @@ FileType = Literal[ }, ) async def service_download_file( - task_id: str = FastApiPath( - ..., description="已完成任务的ID", examples=["b2865b93"] - ), - file_type: FileType = FastApiPath( - ..., - description="要下载的文件类型。", - examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], - ), + task_id: str = FastApiPath( + ..., description="已完成任务的ID", examples=["b2865b93"] + ), + file_type: FileType = FastApiPath( + ..., + description="要下载的文件类型。", + examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], + ), ): task_state = tasks_state.get(task_id) if not task_state: @@ -2319,12 +2411,12 @@ async def service_download_file( }, ) async def service_download_attachment( - task_id: str = FastApiPath( - ..., description="已完成任务的ID", examples=["g1h2i3j4"] - ), - identifier: str = FastApiPath( - ..., description="要下载的附件的标识符。", examples=["glossary"] - ), + task_id: str = FastApiPath( + ..., description="已完成任务的ID", examples=["g1h2i3j4"] + ), + identifier: str = FastApiPath( + ..., 
description="要下载的附件的标识符。", examples=["glossary"] + ), ): task_state = tasks_state.get(task_id) if not task_state: @@ -2404,14 +2496,14 @@ async def service_download_attachment( }, ) async def service_content( - task_id: str = FastApiPath( - ..., description="已完成任务的ID", examples=["b2865b93"] - ), - file_type: FileType = FastApiPath( - ..., - description="要获取内容的文件类型。", - examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], - ), + task_id: str = FastApiPath( + ..., description="已完成任务的ID", examples=["b2865b93"] + ), + file_type: FileType = FastApiPath( + ..., + description="要获取内容的文件类型。", + examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"], + ), ): task_state = tasks_state.get(task_id) if not task_state: @@ -2527,22 +2619,22 @@ async def redoc_html(): @app.post("/temp/translate", tags=["Temp"]) async def temp_translate( - base_url: str = Body(...), - api_key: str = Body("xx"), - model_id: str = Body(...), - mineru_token: Optional[str] = Body(None), - file_name: str = Body(...), - file_content: str = Body(...), - to_lang: str = Body("中文"), - concurrent: int = Body(default_params["concurrent"]), - temperature: float = Body(default_params["temperature"]), - thinking: ThinkingMode = Body(default_params["thinking"]), - chunk_size: int = Body(default_params["chunk_size"]), - custom_prompt: Optional[str] = Body(None), - model_version: Literal["pipeline", "vlm"] = Body("vlm"), - glossary_dict: Optional[Dict[str, str]] = Body(None), - rpm: Optional[int] = Body(None), - tpm: Optional[int] = Body(None), + base_url: str = Body(...), + api_key: str = Body("xx"), + model_id: str = Body(...), + mineru_token: Optional[str] = Body(None), + file_name: str = Body(...), + file_content: str = Body(...), + to_lang: str = Body("中文"), + concurrent: int = Body(default_params["concurrent"]), + temperature: float = Body(default_params["temperature"]), + thinking: ThinkingMode = Body(default_params["thinking"]), + chunk_size: int = 
Body(default_params["chunk_size"]), + custom_prompt: Optional[str] = Body(None), + model_version: Literal["pipeline", "vlm"] = Body("vlm"), + glossary_dict: Optional[Dict[str, str]] = Body(None), + rpm: Optional[int] = Body(None), + tpm: Optional[int] = Body(None), ): file_name = Path(file_name) try: @@ -2594,7 +2686,8 @@ def find_free_port(start_port): port += 1 -def run_app(host=None,port: int | None = None,enable_CORS=False,allow_origin_regex=r"^(https?://.*|null|file://.*)$"): +def run_app(host=None, port: int | None = None, enable_CORS=False, + allow_origin_regex=r"^(https?://.*|null|file://.*)$"): initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010)) try: port_to_use = find_free_port(initial_port)