增加provider选项

This commit is contained in:
xunbu
2025-12-27 21:26:58 +08:00
parent 8aa0d6cb8c
commit 70a444f2b7
13 changed files with 83 additions and 71 deletions

View File

@@ -15,7 +15,7 @@ from urllib.parse import urlparse
import httpx import httpx
import tiktoken import tiktoken
from docutranslate.agents.thinking.thinking_factory import get_thinking_mode from docutranslate.agents.thinking.thinking_factory import get_thinking_mode, ProviderType
from docutranslate.logger import global_logger from docutranslate.logger import global_logger
from docutranslate.utils.utils import get_httpx_proxies from docutranslate.utils.utils import get_httpx_proxies
@@ -55,6 +55,7 @@ class AgentConfig:
force_json: bool = False force_json: bool = False
rpm: int | None = None # 每分钟请求数限制 rpm: int | None = None # 每分钟请求数限制
tpm: int | None = None # 每分钟Token数限制 tpm: int | None = None # 每分钟Token数限制
provider:ProviderType|None=None
class TotalErrorCounter: class TotalErrorCounter:
@@ -281,7 +282,7 @@ class Agent:
self.baseurl = config.base_url.strip() self.baseurl = config.base_url.strip()
if self.baseurl.endswith("/"): if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1] self.baseurl = self.baseurl[:-1]
self.domain = urlparse(self.baseurl).netloc self.domain = urlparse(self.baseurl).netloc.strip()
self.key = config.api_key.strip() if config.api_key else "xx" self.key = config.api_key.strip() if config.api_key else "xx"
self.model_id = config.model_id.strip() self.model_id = config.model_id.strip()
self.system_prompt = "" self.system_prompt = ""
@@ -302,6 +303,8 @@ class Agent:
# 新增:初始化 encoding 用于估算 # 新增:初始化 encoding 用于估算
self.encoding = self._get_encoding_for_model(self.model_id) self.encoding = self._get_encoding_for_model(self.model_id)
self.provider=config.provider if config.provider is not None else self.domain
def _get_encoding_for_model(self, model_name: str): def _get_encoding_for_model(self, model_name: str):
"""获取 tiktoken encoding如果失败则使用 cl100k_base 兜底""" """获取 tiktoken encoding如果失败则使用 cl100k_base 兜底"""
try: try:
@@ -322,7 +325,7 @@ class Agent:
return len(text) // 4 return len(text) // 4
def _add_thinking_mode(self, data: dict): def _add_thinking_mode(self, data: dict):
thinking_mode_result = get_thinking_mode(self.domain, data.get("model")) thinking_mode_result = get_thinking_mode(self.provider, data.get("model"))
if thinking_mode_result is None: if thinking_mode_result is None:
return return
field_thinking, val_enable, val_disable = thinking_mode_result field_thinking, val_enable, val_disable = thinking_mode_result

View File

@@ -1,12 +1,14 @@
from typing import TypeAlias from typing import TypeAlias, Literal, Any
mode_type:TypeAlias=str ProviderType: TypeAlias = Literal["ollama", "open.bigmodel.cn", "dashscope.aliyuncs.com", "ark.cn-beijing.volces.com", "generativelanguage.googleapis.com", "api.siliconflow.cn", "api.302.ai"]|str
thinking_field: TypeAlias=str ModeType: TypeAlias = Literal["ollama", "bigmodel", "aliyun", "volces", "google", "siliconflow", "default"]
enable_value: TypeAlias= str | dict ThinkingField: TypeAlias = str
disable_value: TypeAlias= str | dict EnableValueType: TypeAlias = str | dict[str,Any] | bool
DisableValueType: TypeAlias = str | dict[str,Any] | bool
ThinkingConfig: TypeAlias= tuple[ThinkingField, EnableValueType, DisableValueType]
thinking_mode: dict[ModeType,ThinkingConfig] = {
thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]={ "ollama": ("reasoning_effort", "medium", "none"),
"bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}), "bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
"aliyun": ( "aliyun": (
"extra_body", "extra_body",
@@ -36,7 +38,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
} }
def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str | dict] | None: def get_thinking_mode_by_model_id(model_id: str) -> ThinkingConfig :
model_id = model_id.strip().lower() model_id = model_id.strip().lower()
if "glm-4.5" in model_id: if "glm-4.5" in model_id:
return thinking_mode["bigmodel"] return thinking_mode["bigmodel"]
@@ -49,8 +51,8 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
return thinking_mode["default"] return thinking_mode["default"]
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None: def get_thinking_mode(provider: ProviderType, model_id: str) -> ThinkingConfig :
provider = provider.strip() provider = provider
if provider == "open.bigmodel.cn": if provider == "open.bigmodel.cn":
return thinking_mode["bigmodel"] return thinking_mode["bigmodel"]
elif provider == "dashscope.aliyuncs.com": elif provider == "dashscope.aliyuncs.com":
@@ -63,8 +65,6 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
return thinking_mode["siliconflow"] return thinking_mode["siliconflow"]
elif provider == "api.302.ai": elif provider == "api.302.ai":
return get_thinking_mode_by_model_id(model_id) return get_thinking_mode_by_model_id(model_id)
elif provider == "ollama":
return thinking_mode["ollama"]
return thinking_mode["default"] return thinking_mode["default"]
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):
# pass

View File

@@ -42,7 +42,8 @@ class AssTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -27,6 +27,7 @@ class AiTranslatorConfig(TranslatorConfig, AgentConfig):
glossary_agent_config: GlossaryAgentConfig | None = None glossary_agent_config: GlossaryAgentConfig | None = None
skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项 skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项
T = TypeVar("T", bound=Document) T = TypeVar("T", bound=Document)
@@ -52,23 +53,21 @@ class AiTranslator(Translator[T]):
self.glossary_agent = GlossaryAgent(config.glossary_agent_config) self.glossary_agent = GlossaryAgent(config.glossary_agent_config)
else: else:
glossary_agent_config = GlossaryAgentConfig( glossary_agent_config = GlossaryAgentConfig(
to_lang=config.to_lang, to_lang=config.to_lang, base_url=config.base_url,
base_url=config.base_url, api_key=config.api_key, model_id=config.model_id, temperature=config.temperature,
api_key=config.api_key, thinking=config.thinking, concurrent=config.concurrent, timeout=config.timeout,
model_id=config.model_id, logger=self.logger, retry=config.retry,
temperature=config.temperature, system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
thinking=config.thinking, rpm=config.rpm,
concurrent=config.concurrent, tpm=config.tpm,
timeout=config.timeout, provider=config.provider,
logger=self.logger,
retry=config.retry,
system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json,
) )
self.glossary_agent = GlossaryAgent(glossary_agent_config) self.glossary_agent = GlossaryAgent(glossary_agent_config)
@abstractmethod @abstractmethod
def translate(self, document: T) -> Document: ... def translate(self, document: T) -> Document:
...
@abstractmethod @abstractmethod
async def translate_async(self, document: T) -> Document: ... async def translate_async(self, document: T) -> Document:
...

View File

@@ -131,7 +131,8 @@ class DocxTranslator(AiTranslator):
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry, logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -53,7 +53,8 @@ class EpubTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -71,7 +71,8 @@ class HtmlTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -39,7 +39,8 @@ class JsonTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.json_paths = config.json_paths self.json_paths = config.json_paths

View File

@@ -42,7 +42,8 @@ class MDTranslator(AiTranslator):
retry=config.retry, retry=config.retry,
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = MDTranslateAgent(agent_config) self.translate_agent = MDTranslateAgent(agent_config)

View File

@@ -48,7 +48,8 @@ class PPTXTranslator(AiTranslator):
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry, logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json, system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -44,7 +44,8 @@ class SrtTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -76,7 +76,8 @@ class TXTTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode

View File

@@ -45,7 +45,8 @@ class XlsxTranslator(AiTranslator):
system_proxy_enable=config.system_proxy_enable, system_proxy_enable=config.system_proxy_enable,
force_json=config.force_json, force_json=config.force_json,
rpm=config.rpm, rpm=config.rpm,
tpm=config.tpm tpm=config.tpm,
provider=config.provider,
) )
self.translate_agent = SegmentsTranslateAgent(agent_config) self.translate_agent = SegmentsTranslateAgent(agent_config)
self.insert_mode = config.insert_mode self.insert_mode = config.insert_mode