增加核心代码、后端对thinking设置的支持
This commit is contained in:
@@ -4,7 +4,9 @@ import time
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
from typing import Literal
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from enum import Enum
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from docutranslate.logger import global_logger
|
from docutranslate.logger import global_logger
|
||||||
@@ -13,6 +15,8 @@ MAX_RETRY_COUNT = 2
|
|||||||
MAX_TOTAL_ERROR_COUNT = 10
|
MAX_TOTAL_ERROR_COUNT = 10
|
||||||
|
|
||||||
|
|
||||||
|
ThinkingMode=Literal["enable", "disable", "default"]
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class AgentConfig:
|
class AgentConfig:
|
||||||
logger: logging.Logger
|
logger: logging.Logger
|
||||||
@@ -23,6 +27,7 @@ class AgentConfig:
|
|||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
max_concurrent: int = 30
|
max_concurrent: int = 30
|
||||||
timeout: int = 2000
|
timeout: int = 2000
|
||||||
|
thinking: ThinkingMode = "default"
|
||||||
|
|
||||||
|
|
||||||
class TotalErrorCounter:
|
class TotalErrorCounter:
|
||||||
@@ -62,10 +67,15 @@ TIMEOUT = 600
|
|||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
|
_think_factory = {
|
||||||
|
"open.bigmodel.cn": ("thinking", "enable", "disabled")
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: AgentConfig):
|
def __init__(self, config: AgentConfig):
|
||||||
self.baseurl = config.baseurl.strip()
|
self.baseurl = config.baseurl.strip()
|
||||||
if self.baseurl.endswith("/"):
|
if self.baseurl.endswith("/"):
|
||||||
self.baseurl = self.baseurl[:-1]
|
self.baseurl = self.baseurl[:-1]
|
||||||
|
self.domain = urlparse(self.baseurl).netloc
|
||||||
self.key = config.key.strip() or "xx"
|
self.key = config.key.strip() or "xx"
|
||||||
self.model_id = config.model_id.strip()
|
self.model_id = config.model_id.strip()
|
||||||
self.system_prompt = config.system_prompt or ""
|
self.system_prompt = config.system_prompt or ""
|
||||||
@@ -74,10 +84,22 @@ class Agent:
|
|||||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
||||||
self.max_concurrent = config.max_concurrent
|
self.max_concurrent = config.max_concurrent
|
||||||
self.timeout = config.timeout
|
self.timeout = config.timeout
|
||||||
|
self.thinking = config.thinking
|
||||||
self.logger = config.logger or global_logger
|
self.logger = config.logger or global_logger
|
||||||
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
||||||
|
|
||||||
|
def _add_thinking_mode(self, data: dict):
|
||||||
|
if self.domain not in self._think_factory:
|
||||||
|
self.logger.info("尚不支持更改该平台的思考模式")
|
||||||
|
return
|
||||||
|
field_thinking, val_enable, val_disable = self._think_factory[self.domain]
|
||||||
|
if self.thinking == "enable":
|
||||||
|
self.logger.info("使用思考模式")
|
||||||
|
data[field_thinking] = val_enable
|
||||||
|
elif self.thinking == "disable":
|
||||||
|
self.logger.info("关闭思考模式")
|
||||||
|
data[field_thinking] = val_disable
|
||||||
|
|
||||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = self.temperature
|
temperature = self.temperature
|
||||||
@@ -93,6 +115,8 @@ class Agent:
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
}
|
}
|
||||||
|
if self.thinking != "default":
|
||||||
|
self._add_thinking_mode(data)
|
||||||
return headers, data
|
return headers, data
|
||||||
|
|
||||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from docutranslate import __version__
|
from docutranslate import __version__
|
||||||
|
from docutranslate.agents.agent import ThinkingMode
|
||||||
from docutranslate.cacher import md_based_convert_cacher
|
from docutranslate.cacher import md_based_convert_cacher
|
||||||
# --- 核心代码 Imports ---
|
# --- 核心代码 Imports ---
|
||||||
from docutranslate.global_values.conditional_import import DOCLING_EXIST
|
from docutranslate.global_values.conditional_import import DOCLING_EXIST
|
||||||
@@ -1081,6 +1082,7 @@ async def temp_translate(
|
|||||||
to_lang: str = Body("中文", description="目标语言。", examples=["中文", "英文", "English"]),
|
to_lang: str = Body("中文", description="目标语言。", examples=["中文", "英文", "English"]),
|
||||||
concurrent: int = Body(default_params["concurrent"], description="ai翻译请求并发数"),
|
concurrent: int = Body(default_params["concurrent"], description="ai翻译请求并发数"),
|
||||||
temperature: float = Body(default_params["temperature"], description="ai翻译请求温度"),
|
temperature: float = Body(default_params["temperature"], description="ai翻译请求温度"),
|
||||||
|
thinking: ThinkingMode = Body(default_params["thinking"], description="是否启用深度思考", examples=["default", "enable", "disable"]),
|
||||||
chunk_size: int = Body(default_params["chunk_size"], description="文本分块大小(bytes)"),
|
chunk_size: int = Body(default_params["chunk_size"], description="文本分块大小(bytes)"),
|
||||||
custom_prompt: Optional[str] = Body(None, description="翻译自定义提示词",
|
custom_prompt: Optional[str] = Body(None, description="翻译自定义提示词",
|
||||||
examples=["人名保持原文不翻译"]),
|
examples=["人名保持原文不翻译"]),
|
||||||
@@ -1100,7 +1102,7 @@ async def temp_translate(
|
|||||||
if isinstance(workflow, MarkdownBasedWorkflow):
|
if isinstance(workflow, MarkdownBasedWorkflow):
|
||||||
translator_config = MDTranslatorConfig(
|
translator_config = MDTranslatorConfig(
|
||||||
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
||||||
custom_prompt=custom_prompt, temperature=temperature,
|
custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
|
||||||
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
||||||
)
|
)
|
||||||
convert_config = ConverterMineruConfig(mineru_token=mineru_token,
|
convert_config = ConverterMineruConfig(mineru_token=mineru_token,
|
||||||
@@ -1117,7 +1119,7 @@ async def temp_translate(
|
|||||||
elif isinstance(workflow, TXTWorkflow):
|
elif isinstance(workflow, TXTWorkflow):
|
||||||
translator_config = TXTTranslatorConfig(
|
translator_config = TXTTranslatorConfig(
|
||||||
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
base_url=base_url, api_key=api_key, model_id=model_id, to_lang=to_lang,
|
||||||
custom_prompt=custom_prompt, temperature=temperature,
|
custom_prompt=custom_prompt, temperature=temperature,thinking=thinking,
|
||||||
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
chunk_size=chunk_size, concurrent=concurrent, logger=global_logger, timeout=2000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
default_params = {
|
default_params = {
|
||||||
|
"thinking":"default",
|
||||||
"chunk_size": 3000,
|
"chunk_size": 3000,
|
||||||
"concurrent": 30,
|
"concurrent": 30,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from docutranslate.agents.agent import ThinkingMode
|
||||||
from docutranslate.ir.document import Document
|
from docutranslate.ir.document import Document
|
||||||
from docutranslate.translator.base import Translator, TranslatorConfig
|
from docutranslate.translator.base import Translator, TranslatorConfig
|
||||||
|
|
||||||
@@ -15,21 +16,27 @@ class AiTranslatorConfig(TranslatorConfig):
|
|||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str | None = None
|
custom_prompt: str | None = None
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
|
thinking: ThinkingMode = "default"
|
||||||
timeout: int = 2000
|
timeout: int = 2000
|
||||||
chunk_size: int = 3000
|
chunk_size: int = 3000
|
||||||
concurrent: int = 30
|
concurrent: int = 30
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T', bound=Document)
|
T = TypeVar('T', bound=Document)
|
||||||
|
|
||||||
|
|
||||||
class AiTranslator(Translator[T]):
|
class AiTranslator(Translator[T]):
|
||||||
"""
|
"""
|
||||||
翻译中间文本(原地替换),Translator不做格式转换
|
翻译中间文本(原地替换),Translator不做格式转换
|
||||||
"""
|
"""
|
||||||
def __init__(self,config:AiTranslatorConfig,logger:Logger|None=None):
|
|
||||||
super().__init__(config=config,logger=logger)
|
def __init__(self, config: AiTranslatorConfig):
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def translate(self, document: T) -> Document:
|
def translate(self, document: T) -> Document:
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def translate_async(self, document: T) -> Document:
|
async def translate_async(self, document: T) -> Document:
|
||||||
...
|
...
|
||||||
@@ -28,6 +28,7 @@ class MDTranslator(Translator):
|
|||||||
model_id=config.model_id,
|
model_id=config.model_id,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
|
thinking=config.thinking,
|
||||||
max_concurrent=config.concurrent,
|
max_concurrent=config.concurrent,
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
logger=self.logger)
|
logger=self.logger)
|
||||||
|
|||||||
Reference in New Issue
Block a user