增加provider
This commit is contained in:
@@ -53,6 +53,7 @@ from pydantic import (
|
||||
AliasChoices,
|
||||
ConfigDict,
|
||||
Json,
|
||||
TypeAdapter, # Added TypeAdapter
|
||||
)
|
||||
|
||||
from docutranslate import __version__
|
||||
@@ -176,14 +177,43 @@ def _create_default_task_state() -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
# Maps a lower-cased file suffix (including the leading dot, exactly as
# returned by ``Path.suffix``) to the workflow type that handles it.
_WORKFLOW_TYPE_BY_EXTENSION = {
    ".pdf": "markdown_based",
    ".png": "markdown_based",
    ".jpg": "markdown_based",
    ".md": "markdown_based",
    ".markdown": "markdown_based",
    ".docx": "docx",
    ".doc": "docx",
    ".csv": "xlsx",
    ".xlsx": "xlsx",
    ".xls": "xlsx",
    ".pptx": "pptx",
    # The leading dot is required: ``Path.suffix`` yields ".ppt", never "ppt".
    ".ppt": "pptx",
    ".json": "json",
    ".srt": "srt",
    ".ass": "ass",
    ".epub": "epub",
    ".html": "html",
    ".htm": "html",
    ".txt": "txt",
}


def get_workflow_type_from_filename(filename: str) -> str:
    """Select a ``workflow_type`` from the extension of *filename*.

    The comparison is case-insensitive and only looks at the final suffix
    (``"a.tar.json"`` is treated as ``".json"``).

    Args:
        filename: File name or path whose suffix decides the workflow.

    Returns:
        The workflow type string for the suffix. Names without a suffix and
        unrecognised suffixes fall back to ``"txt"``.
    """
    ext = Path(filename).suffix.lower()
    return _WORKFLOW_TYPE_BY_EXTENSION.get(ext, "txt")
|
||||
|
||||
|
||||
# --- 日志处理器 ---
|
||||
class QueueAndHistoryHandler(logging.Handler):
|
||||
def __init__(
|
||||
self,
|
||||
queue_ref: asyncio.Queue,
|
||||
history_list_ref: List[str],
|
||||
max_history_items: int,
|
||||
task_id: str,
|
||||
self,
|
||||
queue_ref: asyncio.Queue,
|
||||
history_list_ref: List[str],
|
||||
max_history_items: int,
|
||||
task_id: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.queue = queue_ref
|
||||
@@ -279,7 +309,6 @@ DocuTranslate 后端服务 API,提供文档翻译、状态查询、结果下
|
||||
**版本**: {__version__}
|
||||
""",
|
||||
version=__version__,
|
||||
openapi_tags=tags_metadata,
|
||||
)
|
||||
# mimetypes.add_type("application/wasm", ".wasm")
|
||||
service_router = APIRouter(prefix="/service", tags=["Service API"])
|
||||
@@ -440,18 +469,26 @@ class BaseWorkflowParams(BaseModel):
|
||||
if not values.get("skip_translate"):
|
||||
# Check for standard keys or their aliases
|
||||
if not (values.get("base_url") or values.get("baseurl")):
|
||||
raise ValueError(
|
||||
"当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。"
|
||||
)
|
||||
# Auto 模式在校验前不强制要求 base_url
|
||||
if values.get("workflow_type") != "auto":
|
||||
raise ValueError(
|
||||
"当 `skip_translate` 为 `False` 时, `base_url` 或 `baseurl` 字段是必须的。"
|
||||
)
|
||||
if not values.get("model_id"):
|
||||
raise ValueError(
|
||||
"当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。"
|
||||
)
|
||||
if values.get("workflow_type") != "auto":
|
||||
raise ValueError(
|
||||
"当 `skip_translate` 为 `False` 时, `model_id` 字段是必须的。"
|
||||
)
|
||||
# 如果跳过翻译,则不进行任何检查,允许 base_url 等字段为空
|
||||
return values
|
||||
|
||||
|
||||
# 2. 为每个工作流创建独立的参数模型
|
||||
class AutoWorkflowParams(BaseWorkflowParams):
    # Parameter model selected when the payload's workflow_type is "auto".
    # Extra keys are allowed so that options meant for a concrete workflow
    # are accepted on this model instead of being rejected.
    model_config = ConfigDict(extra="allow")

    # Discriminator value for this model; required, with no default.
    workflow_type: Literal["auto"] = Field(
        ...,
        description="根据文件后缀自动选择工作流。",
    )
|
||||
|
||||
|
||||
class MarkdownWorkflowParams(BaseWorkflowParams):
|
||||
workflow_type: Literal["markdown_based"] = Field(
|
||||
..., description="指定使用基于Markdown的翻译工作流。"
|
||||
@@ -622,7 +659,7 @@ class HtmlWorkflowParams(BaseWorkflowParams):
|
||||
)
|
||||
insert_mode: Literal["replace", "append", "prepend"] = Field(
|
||||
"replace",
|
||||
description="翻译文本的插入模式。'replace':替换原文,'append':附加到原文后,'prepend':附加到原文前。",
|
||||
description="翻译文本的插入模式。'replace':替换原文,'append' :附加到原文后,'prepend':附加到原文前。",
|
||||
)
|
||||
separator: str = Field(
|
||||
" ",
|
||||
@@ -673,6 +710,7 @@ class PPTXWorkflowParams(BaseWorkflowParams):
|
||||
# 3. 使用可辨识联合类型(Discriminated Union)将它们组合起来
|
||||
TranslatePayload = Annotated[
|
||||
Union[
|
||||
AutoWorkflowParams,
|
||||
MarkdownWorkflowParams,
|
||||
TextWorkflowParams,
|
||||
JsonWorkflowParams,
|
||||
@@ -714,6 +752,17 @@ class TranslateServiceRequest(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"file_name": "auto_detect_doc.pdf",
|
||||
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
|
||||
"payload": {
|
||||
"workflow_type": "auto",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-your-api-key-here",
|
||||
"model_id": "gpt-4o",
|
||||
"to_lang": "中文",
|
||||
},
|
||||
},
|
||||
{
|
||||
"file_name": "annual_report_203.pdf",
|
||||
"file_content": "JVBERi0xLjcKJeLjz9MKMSAwIG9iago8PC9...",
|
||||
@@ -956,10 +1005,10 @@ class TranslateServiceRequest(BaseModel):
|
||||
|
||||
# --- Background Task Logic ---
|
||||
async def _perform_translation(
|
||||
task_id: str,
|
||||
payload: TranslatePayload,
|
||||
file_contents: bytes,
|
||||
original_filename: str,
|
||||
task_id: str,
|
||||
payload: TranslatePayload,
|
||||
file_contents: bytes,
|
||||
original_filename: str,
|
||||
):
|
||||
task_state = tasks_state[task_id]
|
||||
log_queue = tasks_log_queues[task_id]
|
||||
@@ -1462,7 +1511,6 @@ async def _perform_translation(
|
||||
# 定义导出函数映射
|
||||
export_map = {}
|
||||
|
||||
|
||||
if isinstance(workflow, MDFormatsExportable):
|
||||
export_map["markdown"] = (
|
||||
workflow.export_to_markdown,
|
||||
@@ -1660,11 +1708,37 @@ async def _perform_translation(
|
||||
|
||||
# --- 核心任务启动逻辑 ---
|
||||
async def _start_translation_task(
|
||||
task_id: str,
|
||||
payload: TranslatePayload,
|
||||
file_contents: bytes,
|
||||
original_filename: str,
|
||||
task_id: str,
|
||||
payload: TranslatePayload,
|
||||
file_contents: bytes,
|
||||
original_filename: str,
|
||||
):
|
||||
# --- 新增: Auto 工作流路由逻辑 ---
|
||||
if payload.workflow_type == "auto":
|
||||
detected_type = get_workflow_type_from_filename(original_filename)
|
||||
print(f"[{task_id}] 自动识别工作流: {original_filename} -> {detected_type}")
|
||||
|
||||
# 将参数转换为目标具体工作流类型所需的字典
|
||||
payload_data = payload.model_dump()
|
||||
payload_data["workflow_type"] = detected_type
|
||||
|
||||
# 针对特定格式的默认策略
|
||||
if detected_type == "json" and not payload_data.get("json_paths"):
|
||||
payload_data["json_paths"] = ["$..*"] # 默认翻译所有内容
|
||||
|
||||
if detected_type == "markdown_based" and not payload_data.get("convert_engine"):
|
||||
if Path(original_filename).suffix.lower() == ".pdf":
|
||||
payload_data["convert_engine"] = "mineru" if not DOCLING_EXIST else "docling"
|
||||
else:
|
||||
payload_data["convert_engine"] = "identity"
|
||||
|
||||
# 重新校验为具体的 Payload 类型
|
||||
try:
|
||||
payload = TypeAdapter(TranslatePayload).validate_python(payload_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"自动转换工作流参数失败: {e}")
|
||||
# -----------------------------
|
||||
|
||||
if task_id not in tasks_state:
|
||||
tasks_state[task_id] = _create_default_task_state()
|
||||
tasks_log_queues[task_id] = asyncio.Queue()
|
||||
@@ -1672,9 +1746,9 @@ async def _start_translation_task(
|
||||
task_state = tasks_state[task_id]
|
||||
|
||||
if (
|
||||
task_state["is_processing"]
|
||||
and task_state["current_task_ref"]
|
||||
and not task_state["current_task_ref"].done()
|
||||
task_state["is_processing"]
|
||||
and task_state["current_task_ref"]
|
||||
and not task_state["current_task_ref"].done()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=429, detail=f"任务ID '{task_id}' 正在进行中,请稍后再试。"
|
||||
@@ -1694,7 +1768,7 @@ async def _start_translation_task(
|
||||
"error_flag": False,
|
||||
"download_ready": False,
|
||||
"workflow_instance": None,
|
||||
"original_filename_stem": safe_stem, # 存入安全的stem
|
||||
"original_filename_stem": safe_stem, # 存入安全的stem
|
||||
"original_filename": original_filename,
|
||||
"task_start_time": time.time(),
|
||||
"task_end_time": 0,
|
||||
@@ -1747,9 +1821,9 @@ def _cancel_translation_logic(task_id: str):
|
||||
if not task_state:
|
||||
raise HTTPException(status_code=404, detail=f"找不到任务ID '{task_id}'。")
|
||||
if (
|
||||
not task_state
|
||||
or not task_state["is_processing"]
|
||||
or not task_state["current_task_ref"]
|
||||
not task_state
|
||||
or not task_state["is_processing"]
|
||||
or not task_state["current_task_ref"]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"任务ID '{task_id}' 没有正在进行的翻译任务可取消。"
|
||||
@@ -1778,7 +1852,8 @@ def _cancel_translation_logic(task_id: str):
|
||||
description="""
|
||||
接收一个包含文件内容(Base64编码)和工作流参数的JSON请求,启动一个后台翻译任务。
|
||||
|
||||
- **工作流选择**: 请求体中的 `payload.workflow_type` 字段决定了本次任务的类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`)。
|
||||
- **工作流选择**: `payload.workflow_type` 决定任务类型(如 `markdown_based`, `txt`, `json`, `xlsx`, `docx`, `srt`, `epub`, `html`, `ass`, `pptx`, `auto`)。
|
||||
- **Auto 模式**: 当设置为 `auto` 时,后端将根据 `file_name` 的扩展名自动选择最合适的工作流。
|
||||
- **动态参数**: 根据所选工作流,API需要不同的参数集。请参考下面的Schema或示例。
|
||||
- **异步处理**: 此端点会立即返回任务ID,客户端需轮询状态接口获取进度。
|
||||
""",
|
||||
@@ -1803,9 +1878,9 @@ def _cancel_translation_logic(task_id: str):
|
||||
},
|
||||
)
|
||||
async def service_translate(
|
||||
request: TranslateServiceRequest = Body(
|
||||
..., description="翻译任务的详细参数和文件内容。"
|
||||
)
|
||||
request: TranslateServiceRequest = Body(
|
||||
..., description="翻译任务的详细参数和文件内容。"
|
||||
)
|
||||
):
|
||||
task_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -1840,13 +1915,30 @@ async def service_translate(
|
||||
"/translate/file",
|
||||
summary="提交翻译任务 (文件上传)",
|
||||
description="""
|
||||
接收一个上传的文件和包含工作流参数的JSON字符串,启动一个后台翻译任务。
|
||||
通过 `multipart/form-data` 方式上传文件并启动翻译任务。
|
||||
|
||||
- **工作流选择**: `payload` 表单字段中的 `workflow_type` 字段决定了本次任务的类型。
|
||||
- **文件上传**: 通过 `file` 字段上传文件,替代JSON接口中的 `file_content` 和 `file_name`。
|
||||
- **参数传递**: `payload` 字段应为一个符合 JSON 格式的字符串,其结构与 `/service/translate` 中的 `payload` 字段完全一致。
|
||||
- **异步处理**: 此端点会立即返回任务ID,客户端需轮询状态接口获取进度。
|
||||
""",
|
||||
此接口适用于直接上传二进制文件(如 PDF, Docx 等),无需先进行 Base64 编码。
|
||||
|
||||
### 参数说明
|
||||
- **file**: (必须) 要翻译的二进制文件。
|
||||
- **payload**: (必须) 包含工作流配置的 **JSON 字符串**。
|
||||
- 必须包含 `workflow_type` (如 `auto`, `docx`, `markdown_based` 等)。
|
||||
- 其他参数根据 `workflow_type` 不同而变化 (详见 `TranslatePayload` 模型)。
|
||||
|
||||
### Payload 示例 (JSON String)
|
||||
```json
|
||||
{
|
||||
"workflow_type": "auto",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxxxxx",
|
||||
"model_id": "gpt-4o",
|
||||
"to_lang": "中文"
|
||||
}
|
||||
```
|
||||
|
||||
### 响应
|
||||
返回包含 `task_id` 的 JSON 对象。客户端需使用此 ID 轮询 `/service/status/{task_id}` 接口获取进度。
|
||||
""",
|
||||
responses={
|
||||
200: {
|
||||
"description": "翻译任务已成功启动。",
|
||||
@@ -1868,10 +1960,10 @@ async def service_translate(
|
||||
},
|
||||
)
|
||||
async def service_translate_file(
|
||||
file: UploadFile = File(..., description="要翻译的文件"),
|
||||
payload: Json[TranslatePayload] = Form(
|
||||
..., description="包含工作流参数的JSON字符串,结构与JSON接口的payload一致。"
|
||||
),
|
||||
file: UploadFile = File(..., description="要翻译的文件"),
|
||||
payload: Json[TranslatePayload] = Form(
|
||||
..., description="包含工作流参数的JSON字符串 (详见接口文档说明)。"
|
||||
),
|
||||
):
|
||||
task_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -1925,9 +2017,9 @@ async def service_release_task(task_id: str):
|
||||
task_state = tasks_state.get(task_id)
|
||||
message_parts = []
|
||||
if (
|
||||
task_state
|
||||
and task_state.get("is_processing")
|
||||
and task_state.get("current_task_ref")
|
||||
task_state
|
||||
and task_state.get("is_processing")
|
||||
and task_state.get("current_task_ref")
|
||||
):
|
||||
try:
|
||||
print(f"[{task_id}] 任务正在进行中,将在释放前尝试取消。")
|
||||
@@ -2167,9 +2259,9 @@ async def service_release_task(task_id: str):
|
||||
},
|
||||
)
|
||||
async def service_get_status(
|
||||
task_id: str = FastApiPath(
|
||||
..., description="要查询状态的任务的ID", examples=["b2865b93"]
|
||||
)
|
||||
task_id: str = FastApiPath(
|
||||
..., description="要查询状态的任务的ID", examples=["b2865b93"]
|
||||
)
|
||||
):
|
||||
task_state = tasks_state.get(task_id)
|
||||
if not task_state:
|
||||
@@ -2273,14 +2365,14 @@ FileType = Literal[
|
||||
},
|
||||
)
|
||||
async def service_download_file(
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["b2865b93"]
|
||||
),
|
||||
file_type: FileType = FastApiPath(
|
||||
...,
|
||||
description="要下载的文件类型。",
|
||||
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
|
||||
),
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["b2865b93"]
|
||||
),
|
||||
file_type: FileType = FastApiPath(
|
||||
...,
|
||||
description="要下载的文件类型。",
|
||||
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
|
||||
),
|
||||
):
|
||||
task_state = tasks_state.get(task_id)
|
||||
if not task_state:
|
||||
@@ -2319,12 +2411,12 @@ async def service_download_file(
|
||||
},
|
||||
)
|
||||
async def service_download_attachment(
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["g1h2i3j4"]
|
||||
),
|
||||
identifier: str = FastApiPath(
|
||||
..., description="要下载的附件的标识符。", examples=["glossary"]
|
||||
),
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["g1h2i3j4"]
|
||||
),
|
||||
identifier: str = FastApiPath(
|
||||
..., description="要下载的附件的标识符。", examples=["glossary"]
|
||||
),
|
||||
):
|
||||
task_state = tasks_state.get(task_id)
|
||||
if not task_state:
|
||||
@@ -2404,14 +2496,14 @@ async def service_download_attachment(
|
||||
},
|
||||
)
|
||||
async def service_content(
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["b2865b93"]
|
||||
),
|
||||
file_type: FileType = FastApiPath(
|
||||
...,
|
||||
description="要获取内容的文件类型。",
|
||||
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
|
||||
),
|
||||
task_id: str = FastApiPath(
|
||||
..., description="已完成任务的ID", examples=["b2865b93"]
|
||||
),
|
||||
file_type: FileType = FastApiPath(
|
||||
...,
|
||||
description="要获取内容的文件类型。",
|
||||
examples=["html", "json", "csv", "docx", "srt", "epub", "ass", "pptx"],
|
||||
),
|
||||
):
|
||||
task_state = tasks_state.get(task_id)
|
||||
if not task_state:
|
||||
@@ -2527,22 +2619,22 @@ async def redoc_html():
|
||||
|
||||
@app.post("/temp/translate", tags=["Temp"])
|
||||
async def temp_translate(
|
||||
base_url: str = Body(...),
|
||||
api_key: str = Body("xx"),
|
||||
model_id: str = Body(...),
|
||||
mineru_token: Optional[str] = Body(None),
|
||||
file_name: str = Body(...),
|
||||
file_content: str = Body(...),
|
||||
to_lang: str = Body("中文"),
|
||||
concurrent: int = Body(default_params["concurrent"]),
|
||||
temperature: float = Body(default_params["temperature"]),
|
||||
thinking: ThinkingMode = Body(default_params["thinking"]),
|
||||
chunk_size: int = Body(default_params["chunk_size"]),
|
||||
custom_prompt: Optional[str] = Body(None),
|
||||
model_version: Literal["pipeline", "vlm"] = Body("vlm"),
|
||||
glossary_dict: Optional[Dict[str, str]] = Body(None),
|
||||
rpm: Optional[int] = Body(None),
|
||||
tpm: Optional[int] = Body(None),
|
||||
base_url: str = Body(...),
|
||||
api_key: str = Body("xx"),
|
||||
model_id: str = Body(...),
|
||||
mineru_token: Optional[str] = Body(None),
|
||||
file_name: str = Body(...),
|
||||
file_content: str = Body(...),
|
||||
to_lang: str = Body("中文"),
|
||||
concurrent: int = Body(default_params["concurrent"]),
|
||||
temperature: float = Body(default_params["temperature"]),
|
||||
thinking: ThinkingMode = Body(default_params["thinking"]),
|
||||
chunk_size: int = Body(default_params["chunk_size"]),
|
||||
custom_prompt: Optional[str] = Body(None),
|
||||
model_version: Literal["pipeline", "vlm"] = Body("vlm"),
|
||||
glossary_dict: Optional[Dict[str, str]] = Body(None),
|
||||
rpm: Optional[int] = Body(None),
|
||||
tpm: Optional[int] = Body(None),
|
||||
):
|
||||
file_name = Path(file_name)
|
||||
try:
|
||||
@@ -2594,7 +2686,8 @@ def find_free_port(start_port):
|
||||
port += 1
|
||||
|
||||
|
||||
def run_app(host=None,port: int | None = None,enable_CORS=False,allow_origin_regex=r"^(https?://.*|null|file://.*)$"):
|
||||
def run_app(host=None, port: int | None = None, enable_CORS=False,
|
||||
allow_origin_regex=r"^(https?://.*|null|file://.*)$"):
|
||||
initial_port = port or int(os.environ.get("DOCUTRANSLATE_PORT", 8010))
|
||||
try:
|
||||
port_to_use = find_free_port(initial_port)
|
||||
|
||||
Reference in New Issue
Block a user