增加核心代码、后端对thinking设置的支持

This commit is contained in:
xunbu
2025-08-01 14:21:18 +08:00
parent 190ba01430
commit e31a36bb93
5 changed files with 45 additions and 10 deletions

View File

@@ -4,7 +4,9 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from threading import Lock from threading import Lock
from typing import Literal
from urllib.parse import urlparse
from enum import Enum
import httpx import httpx
from docutranslate.logger import global_logger from docutranslate.logger import global_logger
@@ -13,6 +15,8 @@ MAX_RETRY_COUNT = 2
MAX_TOTAL_ERROR_COUNT = 10 MAX_TOTAL_ERROR_COUNT = 10
ThinkingMode=Literal["enable", "disable", "default"]
@dataclass(kw_only=True) @dataclass(kw_only=True)
class AgentConfig: class AgentConfig:
logger: logging.Logger logger: logging.Logger
@@ -23,6 +27,7 @@ class AgentConfig:
temperature: float = 0.7 temperature: float = 0.7
max_concurrent: int = 30 max_concurrent: int = 30
timeout: int = 2000 timeout: int = 2000
thinking: ThinkingMode = "default"
class TotalErrorCounter: class TotalErrorCounter:
@@ -62,10 +67,15 @@ TIMEOUT = 600
class Agent: class Agent:
_think_factory = {
"open.bigmodel.cn": ("thinking", "enable", "disabled")
}
def __init__(self, config: AgentConfig): def __init__(self, config: AgentConfig):
self.baseurl = config.baseurl.strip() self.baseurl = config.baseurl.strip()
if self.baseurl.endswith("/"): if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1] self.baseurl = self.baseurl[:-1]
self.domain = urlparse(self.baseurl).netloc
self.key = config.key.strip() or "xx" self.key = config.key.strip() or "xx"
self.model_id = config.model_id.strip() self.model_id = config.model_id.strip()
self.system_prompt = config.system_prompt or "" self.system_prompt = config.system_prompt or ""
@@ -74,10 +84,22 @@ class Agent:
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False) self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
self.max_concurrent = config.max_concurrent self.max_concurrent = config.max_concurrent
self.timeout = config.timeout self.timeout = config.timeout
self.thinking = config.thinking
self.logger = config.logger or global_logger self.logger = config.logger or global_logger
self.total_error_counter = TotalErrorCounter(logger=self.logger) self.total_error_counter = TotalErrorCounter(logger=self.logger)
def _add_thinking_mode(self, data: dict):
if self.domain not in self._think_factory:
self.logger.info("尚不支持更改该平台的思考模式")
return
field_thinking, val_enable, val_disable = self._think_factory[self.domain]
if self.thinking == "enable":
self.logger.info("使用思考模式")
data[field_thinking] = val_enable
elif self.thinking == "disable":
self.logger.info("关闭思考模式")
data[field_thinking] = val_disable
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
if temperature is None: if temperature is None:
temperature = self.temperature temperature = self.temperature
@@ -93,6 +115,8 @@ class Agent:
"temperature": temperature, "temperature": temperature,
"top_p": top_p, "top_p": top_p,
} }
if self.thinking != "default":
self._add_thinking_mode(data)
return headers, data return headers, data
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str: async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:

View File

@@ -21,6 +21,7 @@ from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from docutranslate import __version__ from docutranslate import __version__
from docutranslate.agents.agent import ThinkingMode
from docutranslate.cacher import md_based_convert_cacher from docutranslate.cacher import md_based_convert_cacher
# --- 核心代码 Imports --- # --- 核心代码 Imports ---
from docutranslate.global_values.conditional_import import DOCLING_EXIST from docutranslate.global_values.conditional_import import DOCLING_EXIST
@@ -1081,6 +1082,7 @@ async def temp_translate(
to_lang: str = Body("中文", description="目标语言。", examples=["中文", "英文", "English"]), to_lang: str = Body("中文", description="目标语言。", examples=["中文", "英文", "English"]),
concurrent: int = Body(default_params["concurrent"], description="ai翻译请求并发数"), concurrent: int = Body(default_params["concurrent"], description="ai翻译请求并发数"),
temperature: float = Body(default_params["temperature"], description="ai翻译请求温度"), temperature: float = Body(default_params["temperature"], description="ai翻译请求温度"),
thinking: ThinkingMode = Body(default_params["thinking"], description="是否启用深度思考", examples=["default", "enable", "disable"]),
chunk_size: int = Body(default_params["chunk_size"], description="文本分块大小bytes"), chunk_size: int = Body(default_params["chunk_size"], description="文本分块大小bytes"),
custom_prompt: Optional[str] = Body(None, description="翻译自定义提示词", custom_prompt: Optional[str] = Body(None, description="翻译自定义提示词",
examples=["人名保持原文不翻译"]), examples=["人名保持原文不翻译"]),
@@ -1100,7 +1102,7 @@ async def temp_translate(
if isinstance(workflow, MarkdownBasedWorkflow): if isinstance(workflow, MarkdownBasedWorkflow):
translator_config = MDTranslatorConfig( translator_config = MDTranslatorConfig(
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang, base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
custom_prompt=custom_prompt, temperature=temperature, custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000 chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
) )
convert_config = ConverterMineruConfig(mineru_token=mineru_token, convert_config = ConverterMineruConfig(mineru_token=mineru_token,
@@ -1117,7 +1119,7 @@ async def temp_translate(
elif isinstance(workflow, TXTWorkflow): elif isinstance(workflow, TXTWorkflow):
translator_config = TXTTranslatorConfig( translator_config = TXTTranslatorConfig(
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang, base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
custom_prompt=custom_prompt, temperature=temperature, custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000 chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
) )

View File

@@ -1,4 +1,5 @@
default_params = { default_params = {
"thinking":"default",
"chunk_size": 3000, "chunk_size": 3000,
"concurrent": 30, "concurrent": 30,
"temperature": 0.7, "temperature": 0.7,

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
from logging import Logger from logging import Logger
from typing import TypeVar from typing import TypeVar
from docutranslate.agents.agent import ThinkingMode
from docutranslate.ir.document import Document from docutranslate.ir.document import Document
from docutranslate.translator.base import Translator, TranslatorConfig from docutranslate.translator.base import Translator, TranslatorConfig
@@ -15,21 +16,27 @@ class AiTranslatorConfig(TranslatorConfig):
to_lang: str to_lang: str
custom_prompt: str | None = None custom_prompt: str | None = None
temperature: float = 0.7 temperature: float = 0.7
thinking: ThinkingMode = "default"
timeout: int = 2000 timeout: int = 2000
chunk_size: int = 3000 chunk_size: int = 3000
concurrent: int = 30 concurrent: int = 30
T=TypeVar('T',bound=Document)
T = TypeVar('T', bound=Document)
class AiTranslator(Translator[T]): class AiTranslator(Translator[T]):
""" """
翻译中间文本原地替换Translator不做格式转换 翻译中间文本原地替换Translator不做格式转换
""" """
def __init__(self,config:AiTranslatorConfig,logger:Logger|None=None):
super().__init__(config=config,logger=logger) def __init__(self, config: AiTranslatorConfig):
super().__init__(config=config)
@abstractmethod @abstractmethod
def translate(self, document:T) -> Document: def translate(self, document: T) -> Document:
... ...
@abstractmethod @abstractmethod
async def translate_async(self, document: T) -> Document: async def translate_async(self, document: T) -> Document:
... ...

View File

@@ -28,6 +28,7 @@ class MDTranslator(Translator):
model_id=config.model_id, model_id=config.model_id,
system_prompt=None, system_prompt=None,
temperature=config.temperature, temperature=config.temperature,
thinking=config.thinking,
max_concurrent=config.concurrent, max_concurrent=config.concurrent,
timeout=config.timeout, timeout=config.timeout,
logger=self.logger) logger=self.logger)