增加provider选项
This commit is contained in:
@@ -15,7 +15,7 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
import tiktoken
|
||||
|
||||
from docutranslate.agents.thinking.thinking_factory import get_thinking_mode
|
||||
from docutranslate.agents.thinking.thinking_factory import get_thinking_mode, ProviderType
|
||||
from docutranslate.logger import global_logger
|
||||
from docutranslate.utils.utils import get_httpx_proxies
|
||||
|
||||
@@ -55,6 +55,7 @@ class AgentConfig:
|
||||
force_json: bool = False
|
||||
rpm: int | None = None # 每分钟请求数限制
|
||||
tpm: int | None = None # 每分钟Token数限制
|
||||
provider:ProviderType|None=None
|
||||
|
||||
|
||||
class TotalErrorCounter:
|
||||
@@ -281,7 +282,7 @@ class Agent:
|
||||
self.baseurl = config.base_url.strip()
|
||||
if self.baseurl.endswith("/"):
|
||||
self.baseurl = self.baseurl[:-1]
|
||||
self.domain = urlparse(self.baseurl).netloc
|
||||
self.domain = urlparse(self.baseurl).netloc.strip()
|
||||
self.key = config.api_key.strip() if config.api_key else "xx"
|
||||
self.model_id = config.model_id.strip()
|
||||
self.system_prompt = ""
|
||||
@@ -302,6 +303,8 @@ class Agent:
|
||||
# 新增:初始化 encoding 用于估算
|
||||
self.encoding = self._get_encoding_for_model(self.model_id)
|
||||
|
||||
self.provider=config.provider if config.provider is not None else self.domain
|
||||
|
||||
def _get_encoding_for_model(self, model_name: str):
|
||||
"""获取 tiktoken encoding,如果失败则使用 cl100k_base 兜底"""
|
||||
try:
|
||||
@@ -322,7 +325,7 @@ class Agent:
|
||||
return len(text) // 4
|
||||
|
||||
def _add_thinking_mode(self, data: dict):
|
||||
thinking_mode_result = get_thinking_mode(self.domain, data.get("model"))
|
||||
thinking_mode_result = get_thinking_mode(self.provider, data.get("model"))
|
||||
if thinking_mode_result is None:
|
||||
return
|
||||
field_thinking, val_enable, val_disable = thinking_mode_result
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import TypeAlias
|
||||
from typing import TypeAlias, Literal, Any
|
||||
|
||||
mode_type:TypeAlias=str
|
||||
thinking_field: TypeAlias=str
|
||||
enable_value: TypeAlias= str | dict
|
||||
disable_value: TypeAlias= str | dict
|
||||
ProviderType: TypeAlias = Literal["ollama", "open.bigmodel.cn", "dashscope.aliyuncs.com", "ark.cn-beijing.volces.com", "generativelanguage.googleapis.com", "api.siliconflow.cn", "api.302.ai"]|str
|
||||
ModeType: TypeAlias = Literal["ollama", "bigmodel", "aliyun", "volces", "google", "siliconflow", "default"]
|
||||
ThinkingField: TypeAlias = str
|
||||
EnableValueType: TypeAlias = str | dict[str,Any] | bool
|
||||
DisableValueType: TypeAlias = str | dict[str,Any] | bool
|
||||
ThinkingConfig: TypeAlias= tuple[ThinkingField, EnableValueType, DisableValueType]
|
||||
|
||||
|
||||
thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]={
|
||||
thinking_mode: dict[ModeType,ThinkingConfig] = {
|
||||
"ollama": ("reasoning_effort", "medium", "none"),
|
||||
"bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
||||
"aliyun": (
|
||||
"extra_body",
|
||||
@@ -36,7 +38,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
|
||||
}
|
||||
|
||||
|
||||
def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str | dict] | None:
|
||||
def get_thinking_mode_by_model_id(model_id: str) -> ThinkingConfig :
|
||||
model_id = model_id.strip().lower()
|
||||
if "glm-4.5" in model_id:
|
||||
return thinking_mode["bigmodel"]
|
||||
@@ -49,8 +51,8 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
|
||||
return thinking_mode["default"]
|
||||
|
||||
|
||||
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
|
||||
provider = provider.strip()
|
||||
def get_thinking_mode(provider: ProviderType, model_id: str) -> ThinkingConfig :
|
||||
provider = provider
|
||||
if provider == "open.bigmodel.cn":
|
||||
return thinking_mode["bigmodel"]
|
||||
elif provider == "dashscope.aliyuncs.com":
|
||||
@@ -63,8 +65,6 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
|
||||
return thinking_mode["siliconflow"]
|
||||
elif provider == "api.302.ai":
|
||||
return get_thinking_mode_by_model_id(model_id)
|
||||
elif provider == "ollama":
|
||||
return thinking_mode["ollama"]
|
||||
return thinking_mode["default"]
|
||||
|
||||
|
||||
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):
|
||||
# pass
|
||||
|
||||
@@ -42,7 +42,8 @@ class AssTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -27,6 +27,7 @@ class AiTranslatorConfig(TranslatorConfig, AgentConfig):
|
||||
glossary_agent_config: GlossaryAgentConfig | None = None
|
||||
skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Document)
|
||||
|
||||
|
||||
@@ -52,23 +53,21 @@ class AiTranslator(Translator[T]):
|
||||
self.glossary_agent = GlossaryAgent(config.glossary_agent_config)
|
||||
else:
|
||||
glossary_agent_config = GlossaryAgentConfig(
|
||||
to_lang=config.to_lang,
|
||||
base_url=config.base_url,
|
||||
api_key=config.api_key,
|
||||
model_id=config.model_id,
|
||||
temperature=config.temperature,
|
||||
thinking=config.thinking,
|
||||
concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
to_lang=config.to_lang, base_url=config.base_url,
|
||||
api_key=config.api_key, model_id=config.model_id, temperature=config.temperature,
|
||||
thinking=config.thinking, concurrent=config.concurrent, timeout=config.timeout,
|
||||
logger=self.logger, retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.glossary_agent = GlossaryAgent(glossary_agent_config)
|
||||
|
||||
@abstractmethod
|
||||
def translate(self, document: T) -> Document: ...
|
||||
def translate(self, document: T) -> Document:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def translate_async(self, document: T) -> Document: ...
|
||||
async def translate_async(self, document: T) -> Document:
|
||||
...
|
||||
|
||||
@@ -131,7 +131,8 @@ class DocxTranslator(AiTranslator):
|
||||
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -53,7 +53,8 @@ class EpubTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -71,7 +71,8 @@ class HtmlTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -39,7 +39,8 @@ class JsonTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.json_paths = config.json_paths
|
||||
|
||||
@@ -42,7 +42,8 @@ class MDTranslator(AiTranslator):
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = MDTranslateAgent(agent_config)
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ class PPTXTranslator(AiTranslator):
|
||||
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -44,7 +44,8 @@ class SrtTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -76,7 +76,8 @@ class TXTTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -45,7 +45,8 @@ class XlsxTranslator(AiTranslator):
|
||||
system_proxy_enable=config.system_proxy_enable,
|
||||
force_json=config.force_json,
|
||||
rpm=config.rpm,
|
||||
tpm=config.tpm
|
||||
tpm=config.tpm,
|
||||
provider=config.provider,
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
Reference in New Issue
Block a user