改进provider
This commit is contained in:
@@ -15,6 +15,7 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
import tiktoken
|
||||
|
||||
from docutranslate.agents.provider import get_provider_by_domain
|
||||
from docutranslate.agents.thinking.thinking_factory import get_thinking_mode, ProviderType
|
||||
from docutranslate.logger import global_logger
|
||||
from docutranslate.utils.utils import get_httpx_proxies
|
||||
@@ -303,7 +304,7 @@ class Agent:
|
||||
# 新增:初始化 encoding 用于估算
|
||||
self.encoding = self._get_encoding_for_model(self.model_id)
|
||||
|
||||
self.provider=config.provider if config.provider is not None else self.domain
|
||||
self.provider=config.provider if config.provider is not None else get_provider_by_domain(self.domain)
|
||||
|
||||
def _get_encoding_for_model(self, model_name: str):
|
||||
"""获取 tiktoken encoding,如果失败则使用 cl100k_base 兜底"""
|
||||
|
||||
1
docutranslate/agents/provider/__init__.py
Normal file
1
docutranslate/agents/provider/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .provider import get_provider_by_domain,ProviderType
|
||||
16
docutranslate/agents/provider/provider.py
Normal file
16
docutranslate/agents/provider/provider.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import TypeAlias, Literal
|
||||
|
||||
ProviderType: TypeAlias = Literal["ollama", "bigmodel", "aliyun", "volces", "google", "siliconflow", "default"]
|
||||
|
||||
def get_provider_by_domain(domain:str)->ProviderType:
|
||||
if domain == "open.bigmodel.cn":
|
||||
return "bigmodel"
|
||||
elif domain == "dashscope.aliyuncs.com":
|
||||
return "aliyun"
|
||||
elif domain == "ark.cn-beijing.volces.com":
|
||||
return "volces"
|
||||
elif domain == "generativelanguage.googleapis.com":
|
||||
return "google"
|
||||
elif domain == "api.siliconflow.cn":
|
||||
return "siliconflow"
|
||||
return "default"
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import TypeAlias, Literal, Any
|
||||
|
||||
ProviderType: TypeAlias = Literal["ollama", "open.bigmodel.cn", "dashscope.aliyuncs.com", "ark.cn-beijing.volces.com", "generativelanguage.googleapis.com", "api.siliconflow.cn", "api.302.ai","api.openai.com"]|str
|
||||
from docutranslate.agents.provider import ProviderType
|
||||
|
||||
ModeType: TypeAlias = Literal["ollama", "bigmodel", "aliyun", "volces", "google", "siliconflow", "default"]
|
||||
ThinkingField: TypeAlias = str
|
||||
EnableValueType: TypeAlias = str | dict[str,Any] | bool
|
||||
DisableValueType: TypeAlias = str | dict[str,Any] | bool
|
||||
ThinkingConfig: TypeAlias= tuple[ThinkingField, EnableValueType, DisableValueType]
|
||||
|
||||
thinking_mode: dict[ModeType,ThinkingConfig] = {
|
||||
thinking_mode: dict[ProviderType,ThinkingConfig] = {
|
||||
"ollama": ("reasoning_effort", "medium", "none"),
|
||||
"bigmodel": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
||||
"aliyun": (
|
||||
@@ -53,18 +54,16 @@ def get_thinking_mode_by_model_id(model_id: str) -> ThinkingConfig :
|
||||
|
||||
def get_thinking_mode(provider: ProviderType, model_id: str) -> ThinkingConfig :
|
||||
provider = provider
|
||||
if provider == "open.bigmodel.cn":
|
||||
if provider == "bigmodel":
|
||||
return thinking_mode["bigmodel"]
|
||||
elif provider == "dashscope.aliyuncs.com":
|
||||
elif provider == "aliyun":
|
||||
return thinking_mode["aliyun"]
|
||||
elif provider == "ark.cn-beijing.volces.com":
|
||||
elif provider == "volces":
|
||||
return thinking_mode["volces"]
|
||||
elif provider == "generativelanguage.googleapis.com":
|
||||
elif provider == "google":
|
||||
return thinking_mode["google"]
|
||||
elif provider == "api.siliconflow.cn":
|
||||
elif provider == "siliconflow":
|
||||
return thinking_mode["siliconflow"]
|
||||
elif provider == "api.302.ai":
|
||||
return get_thinking_mode_by_model_id(model_id)
|
||||
elif provider == "ollama":
|
||||
return thinking_mode["ollama"]
|
||||
return thinking_mode["default"]
|
||||
return get_thinking_mode_by_model_id(model_id)
|
||||
@@ -317,18 +317,7 @@ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
# ===================================================================
|
||||
# --- Pydantic Models for Service API ---
|
||||
# ===================================================================
|
||||
|
||||
ProviderType: TypeAlias = Literal[
|
||||
"ollama",
|
||||
"open.bigmodel.cn",
|
||||
"dashscope.aliyuncs.com",
|
||||
"ark.cn-beijing.volces.com",
|
||||
"generativelanguage.googleapis.com",
|
||||
"api.siliconflow.cn",
|
||||
"api.302.ai"
|
||||
] | str
|
||||
|
||||
# =================================================================
|
||||
|
||||
|
||||
# 4. 创建最终的请求体模型
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user