From 70a444f2b71d2403963ef0e7304f06816c2e17d1 Mon Sep 17 00:00:00 2001 From: xunbu Date: Sat, 27 Dec 2025 21:26:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0provider=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docutranslate/agents/agent.py | 9 +- .../agents/thinking/thinking_factory.py | 86 +++++++++---------- .../ai_translator/ass_translator.py | 3 +- .../translator/ai_translator/base.py | 29 +++---- .../ai_translator/docx_translator.py | 3 +- .../ai_translator/epub_translator.py | 3 +- .../ai_translator/html_translator.py | 3 +- .../ai_translator/json_translator.py | 3 +- .../translator/ai_translator/md_translator.py | 3 +- .../ai_translator/pptx_translator.py | 3 +- .../ai_translator/srt_translator.py | 3 +- .../ai_translator/txt_translator.py | 3 +- .../ai_translator/xlsx_translator.py | 3 +- 13 files changed, 83 insertions(+), 71 deletions(-) diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 32e6142..7b909f5 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -15,7 +15,7 @@ from urllib.parse import urlparse import httpx import tiktoken -from docutranslate.agents.thinking.thinking_factory import get_thinking_mode +from docutranslate.agents.thinking.thinking_factory import get_thinking_mode, ProviderType from docutranslate.logger import global_logger from docutranslate.utils.utils import get_httpx_proxies @@ -55,6 +55,7 @@ class AgentConfig: force_json: bool = False rpm: int | None = None # 每分钟请求数限制 tpm: int | None = None # 每分钟Token数限制 + provider:ProviderType|None=None class TotalErrorCounter: @@ -281,7 +282,7 @@ class Agent: self.baseurl = config.base_url.strip() if self.baseurl.endswith("/"): self.baseurl = self.baseurl[:-1] - self.domain = urlparse(self.baseurl).netloc + self.domain = urlparse(self.baseurl).netloc.strip() self.key = config.api_key.strip() if config.api_key else "xx" self.model_id = config.model_id.strip() self.system_prompt = "" @@ -302,6 +303,8 @@ class Agent: # 新增:初始化 encoding 用于估算 self.encoding = self._get_encoding_for_model(self.model_id) + self.provider=config.provider if config.provider is not None else self.domain + def _get_encoding_for_model(self, model_name: str): """获取 tiktoken encoding,如果失败则使用 cl100k_base 兜底""" try: @@ -322,7 +325,7 @@ class Agent: return len(text) // 4 def _add_thinking_mode(self, data: dict): - thinking_mode_result = get_thinking_mode(self.domain, data.get("model")) + thinking_mode_result = get_thinking_mode(self.provider, data.get("model")) if thinking_mode_result is None: return field_thinking, val_enable, val_disable = thinking_mode_result diff --git a/docutranslate/agents/thinking/thinking_factory.py b/docutranslate/agents/thinking/thinking_factory.py index a1aff52..4c0a8e9 100644 --- a/docutranslate/agents/thinking/thinking_factory.py +++ b/docutranslate/agents/thinking/thinking_factory.py @@ -1,42 +1,44 @@ -from typing import TypeAlias +from typing import TypeAlias, Literal, Any -mode_type:TypeAlias=str -thinking_field: TypeAlias=str -enable_value: TypeAlias= str | dict -disable_value: TypeAlias= str | dict +ProviderType: TypeAlias = Literal["ollama", "open.bigmodel.cn", "dashscope.aliyuncs.com", "ark.cn-beijing.volces.com", "generativelanguage.googleapis.com", "api.siliconflow.cn", "api.302.ai"]|str +ModeType: TypeAlias = Literal["ollama", "bigmodel", "aliyun", "volces", "google", "siliconflow", "default"] +ThinkingField: TypeAlias = str +EnableValueType: TypeAlias = str | dict[str,Any] | bool +DisableValueType: TypeAlias = str | dict[str,Any] | bool +ThinkingConfig: TypeAlias= tuple[ThinkingField, EnableValueType, DisableValueType] + +thinking_mode: dict[ModeType,ThinkingConfig] = { + "ollama": ("reasoning_effort", "medium", "none"), + "bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}), + "aliyun": ( + "extra_body", + {"enable_thinking": True}, + {"enable_thinking": False}, + ), + "volces": ( + "thinking", + {"type": "enabled"}, + {"type": "disabled"}, + ), + "google": ( + "extra_body", + { + "google": { + "thinking_config": {"thinking_budget": -1, "include_thoughts": True} + } + }, + { + "google": { + "thinking_config": {"thinking_budget": 0, "include_thoughts": False} + } + }, + ), + "siliconflow": ("enable_thinking", True, False), + "default": ("reasoning_effort", "medium", "minimal"), +} -thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]={ - "bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}), - "aliyun": ( - "extra_body", - {"enable_thinking": True}, - {"enable_thinking": False}, - ), - "volces": ( - "thinking", - {"type": "enabled"}, - {"type": "disabled"}, - ), - "google": ( - "extra_body", - { - "google": { - "thinking_config": {"thinking_budget": -1, "include_thoughts": True} - } - }, - { - "google": { - "thinking_config": {"thinking_budget": 0, "include_thoughts": False} - } - }, - ), - "siliconflow": ("enable_thinking", True, False), - "default":("reasoning_effort","medium","minimal"), - } - - -def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str | dict] | None: +def get_thinking_mode_by_model_id(model_id: str) -> ThinkingConfig : model_id = model_id.strip().lower() if "glm-4.5" in model_id: return thinking_mode["bigmodel"] @@ -49,8 +51,8 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str | return thinking_mode["default"] -def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None: - provider = provider.strip() +def get_thinking_mode(provider: ProviderType, model_id: str) -> ThinkingConfig : + provider = provider if provider == "open.bigmodel.cn": return thinking_mode["bigmodel"] elif provider == "dashscope.aliyuncs.com": @@ -63,8 +65,6 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st return thinking_mode["siliconflow"] elif provider == "api.302.ai": return get_thinking_mode_by_model_id(model_id) - return thinking_mode["default"] - - -# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool): -# pass + elif provider == "ollama": + return thinking_mode["ollama"] + return thinking_mode["default"] \ No newline at end of file diff --git a/docutranslate/translator/ai_translator/ass_translator.py b/docutranslate/translator/ai_translator/ass_translator.py index 31c1321..cdef8b3 100644 --- a/docutranslate/translator/ai_translator/ass_translator.py +++ b/docutranslate/translator/ai_translator/ass_translator.py @@ -42,7 +42,8 @@ class AssTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/base.py b/docutranslate/translator/ai_translator/base.py index 2dda397..f9603fe 100644 --- a/docutranslate/translator/ai_translator/base.py +++ b/docutranslate/translator/ai_translator/base.py @@ -27,6 +27,7 @@ class AiTranslatorConfig(TranslatorConfig, AgentConfig): glossary_agent_config: GlossaryAgentConfig | None = None skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项 + T = TypeVar("T", bound=Document) @@ -41,7 +42,7 @@ class AiTranslator(Translator[T]): self.glossary_agent = None self.glossary_dict_gen = None if not self.skip_translate and ( - config.base_url is None or config.api_key is None or config.model_id is None + config.base_url is None or config.api_key is None or config.model_id is None ): raise ValueError( "skip_translate不为false时,base_url、api_key、model_id为必填项" @@ -52,23 +53,21 @@ class AiTranslator(Translator[T]): self.glossary_agent = GlossaryAgent(config.glossary_agent_config) else: glossary_agent_config = GlossaryAgentConfig( - to_lang=config.to_lang, - base_url=config.base_url, - api_key=config.api_key, - model_id=config.model_id, - temperature=config.temperature, - thinking=config.thinking, - concurrent=config.concurrent, - timeout=config.timeout, - logger=self.logger, - retry=config.retry, - system_proxy_enable=config.system_proxy_enable, - force_json=config.force_json, + to_lang=config.to_lang, base_url=config.base_url, + api_key=config.api_key, model_id=config.model_id, temperature=config.temperature, + thinking=config.thinking, concurrent=config.concurrent, timeout=config.timeout, + logger=self.logger, retry=config.retry, + system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, + rpm=config.rpm, + tpm=config.tpm, + provider=config.provider, ) self.glossary_agent = GlossaryAgent(glossary_agent_config) @abstractmethod - def translate(self, document: T) -> Document: ... + def translate(self, document: T) -> Document: + ... @abstractmethod - async def translate_async(self, document: T) -> Document: ... + async def translate_async(self, document: T) -> Document: + ... diff --git a/docutranslate/translator/ai_translator/docx_translator.py b/docutranslate/translator/ai_translator/docx_translator.py index d9e9970..6947a71 100644 --- a/docutranslate/translator/ai_translator/docx_translator.py +++ b/docutranslate/translator/ai_translator/docx_translator.py @@ -131,7 +131,8 @@ class DocxTranslator(AiTranslator): logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry, system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/epub_translator.py b/docutranslate/translator/ai_translator/epub_translator.py index 96000be..9f79e76 100644 --- a/docutranslate/translator/ai_translator/epub_translator.py +++ b/docutranslate/translator/ai_translator/epub_translator.py @@ -53,7 +53,8 @@ class EpubTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/html_translator.py b/docutranslate/translator/ai_translator/html_translator.py index 0d6e3f1..e09ae8f 100644 --- a/docutranslate/translator/ai_translator/html_translator.py +++ b/docutranslate/translator/ai_translator/html_translator.py @@ -71,7 +71,8 @@ class HtmlTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/json_translator.py b/docutranslate/translator/ai_translator/json_translator.py index 211fbdc..33cbea5 100644 --- a/docutranslate/translator/ai_translator/json_translator.py +++ b/docutranslate/translator/ai_translator/json_translator.py @@ -39,7 +39,8 @@ class JsonTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.json_paths = config.json_paths diff --git a/docutranslate/translator/ai_translator/md_translator.py b/docutranslate/translator/ai_translator/md_translator.py index 5600ccf..640c2c7 100644 --- a/docutranslate/translator/ai_translator/md_translator.py +++ b/docutranslate/translator/ai_translator/md_translator.py @@ -42,7 +42,8 @@ class MDTranslator(AiTranslator): retry=config.retry, system_proxy_enable=config.system_proxy_enable, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = MDTranslateAgent(agent_config) diff --git a/docutranslate/translator/ai_translator/pptx_translator.py b/docutranslate/translator/ai_translator/pptx_translator.py index f2fdfb0..13678be 100644 --- a/docutranslate/translator/ai_translator/pptx_translator.py +++ b/docutranslate/translator/ai_translator/pptx_translator.py @@ -48,7 +48,8 @@ class PPTXTranslator(AiTranslator): logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry, system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/srt_translator.py b/docutranslate/translator/ai_translator/srt_translator.py index 6d7a18d..31846d9 100644 --- a/docutranslate/translator/ai_translator/srt_translator.py +++ b/docutranslate/translator/ai_translator/srt_translator.py @@ -44,7 +44,8 @@ class SrtTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/txt_translator.py b/docutranslate/translator/ai_translator/txt_translator.py index 2b8e8b3..47146ed 100644 --- a/docutranslate/translator/ai_translator/txt_translator.py +++ b/docutranslate/translator/ai_translator/txt_translator.py @@ -76,7 +76,8 @@ class TXTTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode diff --git a/docutranslate/translator/ai_translator/xlsx_translator.py b/docutranslate/translator/ai_translator/xlsx_translator.py index 8d63963..d75a1c8 100644 --- a/docutranslate/translator/ai_translator/xlsx_translator.py +++ b/docutranslate/translator/ai_translator/xlsx_translator.py @@ -45,7 +45,8 @@ class XlsxTranslator(AiTranslator): system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, rpm=config.rpm, - tpm=config.tpm + tpm=config.tpm, + provider=config.provider, ) self.translate_agent = SegmentsTranslateAgent(agent_config) self.insert_mode = config.insert_mode