增加核心代码、后端对thinking设置的支持
This commit is contained in:
@@ -4,7 +4,9 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
from enum import Enum
|
||||
import httpx
|
||||
|
||||
from docutranslate.logger import global_logger
|
||||
@@ -13,6 +15,8 @@ MAX_RETRY_COUNT = 2
|
||||
MAX_TOTAL_ERROR_COUNT = 10
|
||||
|
||||
|
||||
ThinkingMode=Literal["enable", "disable", "default"]
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class AgentConfig:
|
||||
logger: logging.Logger
|
||||
@@ -23,6 +27,7 @@ class AgentConfig:
|
||||
temperature: float = 0.7
|
||||
max_concurrent: int = 30
|
||||
timeout: int = 2000
|
||||
thinking: ThinkingMode = "default"
|
||||
|
||||
|
||||
class TotalErrorCounter:
|
||||
@@ -62,10 +67,15 @@ TIMEOUT = 600
|
||||
|
||||
|
||||
class Agent:
|
||||
_think_factory = {
|
||||
"open.bigmodel.cn": ("thinking", "enable", "disabled")
|
||||
}
|
||||
|
||||
def __init__(self, config: AgentConfig):
|
||||
self.baseurl = config.baseurl.strip()
|
||||
if self.baseurl.endswith("/"):
|
||||
self.baseurl = self.baseurl[:-1]
|
||||
self.domain = urlparse(self.baseurl).netloc
|
||||
self.key = config.key.strip() or "xx"
|
||||
self.model_id = config.model_id.strip()
|
||||
self.system_prompt = config.system_prompt or ""
|
||||
@@ -74,10 +84,22 @@ class Agent:
|
||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
||||
self.max_concurrent = config.max_concurrent
|
||||
self.timeout = config.timeout
|
||||
|
||||
self.thinking = config.thinking
|
||||
self.logger = config.logger or global_logger
|
||||
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
||||
|
||||
def _add_thinking_mode(self, data: dict):
|
||||
if self.domain not in self._think_factory:
|
||||
self.logger.info("尚不支持更改该平台的思考模式")
|
||||
return
|
||||
field_thinking, val_enable, val_disable = self._think_factory[self.domain]
|
||||
if self.thinking == "enable":
|
||||
self.logger.info("使用思考模式")
|
||||
data[field_thinking] = val_enable
|
||||
elif self.thinking == "disable":
|
||||
self.logger.info("关闭思考模式")
|
||||
data[field_thinking] = val_disable
|
||||
|
||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
@@ -93,6 +115,8 @@ class Agent:
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
if self.thinking != "default":
|
||||
self._add_thinking_mode(data)
|
||||
return headers, data
|
||||
|
||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
||||
|
||||
@@ -21,6 +21,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from docutranslate import __version__
|
||||
from docutranslate.agents.agent import ThinkingMode
|
||||
from docutranslate.cacher import md_based_convert_cacher
|
||||
# --- 核心代码 Imports ---
|
||||
from docutranslate.global_values.conditional_import import DOCLING_EXIST
|
||||
@@ -1081,6 +1082,7 @@ async def temp_translate(
|
||||
to_lang: str = Body("中文", description="目标语言。", examples=["中文", "英文", "English"]),
|
||||
concurrent: int = Body(default_params["concurrent"], description="ai翻译请求并发数"),
|
||||
temperature: float = Body(default_params["temperature"], description="ai翻译请求温度"),
|
||||
thinking: ThinkingMode = Body(default_params["thinking"], description="是否启用深度思考", examples=["default", "enable", "disable"]),
|
||||
chunk_size: int = Body(default_params["chunk_size"], description="文本分块大小(bytes)"),
|
||||
custom_prompt: Optional[str] = Body(None, description="翻译自定义提示词",
|
||||
examples=["人名保持原文不翻译"]),
|
||||
@@ -1100,7 +1102,7 @@ async def temp_translate(
|
||||
if isinstance(workflow, MarkdownBasedWorkflow):
|
||||
translator_config = MDTranslatorConfig(
|
||||
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
||||
custom_prompt=custom_prompt, temperature=temperature,
|
||||
custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
|
||||
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
||||
)
|
||||
convert_config = ConverterMineruConfig(mineru_token=mineru_token,
|
||||
@@ -1117,7 +1119,7 @@ async def temp_translate(
|
||||
elif isinstance(workflow, TXTWorkflow):
|
||||
translator_config = TXTTranslatorConfig(
|
||||
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
||||
custom_prompt=custom_prompt, temperature=temperature,
|
||||
custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
|
||||
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
default_params = {
|
||||
"thinking":"default",
|
||||
"chunk_size": 3000,
|
||||
"concurrent": 30,
|
||||
"temperature": 0.7,
|
||||
|
||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from typing import TypeVar
|
||||
|
||||
from docutranslate.agents.agent import ThinkingMode
|
||||
from docutranslate.ir.document import Document
|
||||
from docutranslate.translator.base import Translator, TranslatorConfig
|
||||
|
||||
@@ -15,21 +16,27 @@ class AiTranslatorConfig(TranslatorConfig):
|
||||
to_lang: str
|
||||
custom_prompt: str | None = None
|
||||
temperature: float = 0.7
|
||||
thinking: ThinkingMode = "default"
|
||||
timeout: int = 2000
|
||||
chunk_size: int = 3000
|
||||
concurrent: int = 30
|
||||
|
||||
T=TypeVar('T',bound=Document)
|
||||
|
||||
T = TypeVar('T', bound=Document)
|
||||
|
||||
|
||||
class AiTranslator(Translator[T]):
|
||||
"""
|
||||
翻译中间文本(原地替换),Translator不做格式转换
|
||||
"""
|
||||
def __init__(self,config:AiTranslatorConfig,logger:Logger|None=None):
|
||||
super().__init__(config=config,logger=logger)
|
||||
|
||||
def __init__(self, config: AiTranslatorConfig):
|
||||
super().__init__(config=config)
|
||||
|
||||
@abstractmethod
|
||||
def translate(self, document:T) -> Document:
|
||||
def translate(self, document: T) -> Document:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def translate_async(self, document: T) -> Document:
|
||||
...
|
||||
@@ -28,6 +28,7 @@ class MDTranslator(Translator):
|
||||
model_id=config.model_id,
|
||||
system_prompt=None,
|
||||
temperature=config.temperature,
|
||||
thinking=config.thinking,
|
||||
max_concurrent=config.concurrent,
|
||||
timeout=config.timeout,
|
||||
logger=self.logger)
|
||||
|
||||
Reference in New Issue
Block a user