前后端增加system_proxy_enable选项

This commit is contained in:
xunbu
2025-09-24 10:37:42 +08:00
parent f611468a13
commit ac6fcebe24
6 changed files with 1106 additions and 505 deletions

View File

@@ -112,14 +112,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取cached_tokens # 尝试从不同格式获取cached_tokens
# 格式1: input_tokens_details.cached_tokens # 格式1: input_tokens_details.cached_tokens
if ( if (
"input_tokens_details" in usage "input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"] and "cached_tokens" in usage["input_tokens_details"]
): ):
cached_tokens = usage["input_tokens_details"]["cached_tokens"] cached_tokens = usage["input_tokens_details"]["cached_tokens"]
# 格式2: prompt_tokens_details.cached_tokens # 格式2: prompt_tokens_details.cached_tokens
elif ( elif (
"prompt_tokens_details" in usage "prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"] and "cached_tokens" in usage["prompt_tokens_details"]
): ):
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
# 格式3: prompt_cache_hit_tokens (直接在usage下) # 格式3: prompt_cache_hit_tokens (直接在usage下)
@@ -129,14 +129,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取reasoning_tokens # 尝试从不同格式获取reasoning_tokens
# 格式1: output_tokens_details.reasoning_tokens # 格式1: output_tokens_details.reasoning_tokens
if ( if (
"output_tokens_details" in usage "output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"] and "reasoning_tokens" in usage["output_tokens_details"]
): ):
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
# 格式2: completion_tokens_details.reasoning_tokens # 格式2: completion_tokens_details.reasoning_tokens
elif ( elif (
"completion_tokens_details" in usage "completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"] and "reasoning_tokens" in usage["completion_tokens_details"]
): ):
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
return input_tokens, cached_tokens, output_tokens, reasoning_tokens return input_tokens, cached_tokens, output_tokens, reasoning_tokens
@@ -156,11 +156,11 @@ class TokenCounter:
self.logger = logger self.logger = logger
def add( def add(
self, self,
input_tokens: int, input_tokens: int,
cached_tokens: int, cached_tokens: int,
output_tokens: int, output_tokens: int,
reasoning_tokens: int, reasoning_tokens: int,
): ):
with self.lock: with self.lock:
self.input_tokens += input_tokens self.input_tokens += input_tokens
@@ -249,7 +249,8 @@ class Agent:
self.retry = config.retry self.retry = config.retry
self.system_proxy_enable=config.system_proxy_enable self.system_proxy_enable = config.system_proxy_enable
def _add_thinking_mode(self, data: dict): def _add_thinking_mode(self, data: dict):
if self.domain not in self._think_factory: if self.domain not in self._think_factory:
return return
@@ -260,7 +261,7 @@ class Agent:
data[field_thinking] = val_disable data[field_thinking] = val_disable
def _prepare_request_data( def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
): ):
if temperature is None: if temperature is None:
temperature = self.temperature temperature = self.temperature
@@ -282,16 +283,16 @@ class Agent:
return headers, data return headers, data
async def send_async( async def send_async(
self, self,
client: httpx.AsyncClient, client: httpx.AsyncClient,
prompt: str, prompt: str,
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None, best_partial_result: dict | None = None,
) -> Any: ) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
@@ -432,24 +433,24 @@ class Agent:
) )
async def send_prompts_async( async def send_prompts_async(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None, max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]: ) -> list[Any]:
max_concurrent = ( max_concurrent = (
self.max_concurrent if max_concurrent is None else max_concurrent self.max_concurrent if max_concurrent is None else max_concurrent
) )
total = len(prompts) total = len(prompts)
self.logger.info( self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}" f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable}"
) )
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = ( self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR len(prompts) // MAX_REQUESTS_PER_ERROR
) )
# 新增:在每次批量发送前重置计数器 # 新增:在每次批量发送前重置计数器
@@ -469,7 +470,7 @@ class Agent:
) )
async with httpx.AsyncClient( async with httpx.AsyncClient(
trust_env=False, proxies=proxies, verify=False, limits=limits trust_env=False, proxies=proxies, verify=False, limits=limits
) as client: ) as client:
async def send_with_semaphore(p_text: str): async def send_with_semaphore(p_text: str):
@@ -500,7 +501,7 @@ class Agent:
# 新增打印token使用统计 # 新增打印token使用统计
token_stats = self.token_counter.get_stats() token_stats = self.token_counter.get_stats()
if token_stats['input_tokens'] < 0: if token_stats["input_tokens"] < 0:
self.logger.info("Token统计失败") self.logger.info("Token统计失败")
else: else:
self.logger.info( self.logger.info(
@@ -512,16 +513,16 @@ class Agent:
return results return results
def send( def send(
self, self,
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
pre_send_handler=None, pre_send_handler=None,
result_handler=None, result_handler=None,
error_result_handler=None, error_result_handler=None,
best_partial_result: dict | None = None, best_partial_result: dict | None = None,
) -> Any: ) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
@@ -658,14 +659,14 @@ class Agent:
) )
def _send_prompt_count( def _send_prompt_count(
self, self,
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str, system_prompt: None | str,
count: PromptsCounter, count: PromptsCounter,
pre_send_handler, pre_send_handler,
result_handler, result_handler,
error_result_handler, error_result_handler,
) -> Any: ) -> Any:
result = self.send( result = self.send(
client, client,
@@ -679,21 +680,21 @@ class Agent:
return result return result
def send_prompts( def send_prompts(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]: ) -> list[Any]:
self.logger.info( self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}" f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable}"
) )
self.logger.info( self.logger.info(
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
) )
self.total_error_counter.max_errors_count = ( self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR len(prompts) // MAX_REQUESTS_PER_ERROR
) )
# 新增:在每次批量发送前重置计数器 # 新增:在每次批量发送前重置计数器
@@ -714,7 +715,7 @@ class Agent:
) )
proxies = get_httpx_proxies() if self.system_proxy_enable else None proxies = get_httpx_proxies() if self.system_proxy_enable else None
with httpx.Client( with httpx.Client(
trust_env=False, proxies=proxies, verify=False, limits=limits trust_env=False, proxies=proxies, verify=False, limits=limits
) as client: ) as client:
clients = itertools.repeat(client, len(prompts)) clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
@@ -737,7 +738,7 @@ class Agent:
# 新增打印token使用统计 # 新增打印token使用统计
token_stats = self.token_counter.get_stats() token_stats = self.token_counter.get_stats()
if token_stats['input_tokens'] < 0: if token_stats["input_tokens"] < 0:
self.logger.info("Token统计失败") self.logger.info("Token统计失败")
else: else:
self.logger.info( self.logger.info(

File diff suppressed because it is too large Load Diff

View File

@@ -7,4 +7,4 @@ from .conditional_import import available_packages, conditional_import
USE_PROXY = True if (os.getenv("DOCUTRANSLATE_PROXY_ENABLED") and os.getenv( USE_PROXY = True if (os.getenv("DOCUTRANSLATE_PROXY_ENABLED") and os.getenv(
"DOCUTRANSLATE_PROXY_ENABLED").lower() == "true") else False "DOCUTRANSLATE_PROXY_ENABLED").lower() == "true") else False
if USE_PROXY: if USE_PROXY:
print(f"USE_PROXY:{USE_PROXY}") print(f"USE_PROXY:{USE_PROXY}")

File diff suppressed because one or more lines are too long

View File

@@ -12,9 +12,13 @@ from docutranslate.translator.base import Translator, TranslatorConfig
@dataclass(kw_only=True) @dataclass(kw_only=True)
class AiTranslatorConfig(TranslatorConfig, AgentConfig): class AiTranslatorConfig(TranslatorConfig, AgentConfig):
base_url: str | None = field(default=None, base_url: str | None = field(
metadata={"description": "OpenAI兼容地址当skip_translate为False时为必填项"}) default=None,
model_id: str | None = field(default=None, metadata={"description": "当skip_translate为False时为必填项"}) metadata={"description": "OpenAI兼容地址当skip_translate为False时为必填项"},
)
model_id: str | None = field(
default=None, metadata={"description": "当skip_translate为False时为必填项"}
)
to_lang: str = "简体中文" to_lang: str = "简体中文"
custom_prompt: str | None = None custom_prompt: str | None = None
chunk_size: int = 3000 chunk_size: int = 3000
@@ -24,7 +28,7 @@ class AiTranslatorConfig(TranslatorConfig, AgentConfig):
skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项 skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项
T = TypeVar('T', bound=Document) T = TypeVar("T", bound=Document)
class AiTranslator(Translator[T]): class AiTranslator(Translator[T]):
@@ -37,8 +41,12 @@ class AiTranslator(Translator[T]):
self.skip_translate = config.skip_translate self.skip_translate = config.skip_translate
self.glossary_agent = None self.glossary_agent = None
self.glossary_dict_gen = None self.glossary_dict_gen = None
if not self.skip_translate and (config.base_url is None or config.api_key is None or config.model_id is None): if not self.skip_translate and (
raise ValueError("skip_translate不为false时base_url、api_key、model_id为必填项") config.base_url is None or config.api_key is None or config.model_id is None
):
raise ValueError(
"skip_translate不为false时base_url、api_key、model_id为必填项"
)
if config.glossary_generate_enable: if config.glossary_generate_enable:
if config.glossary_agent_config: if config.glossary_agent_config:
@@ -54,14 +62,13 @@ class AiTranslator(Translator[T]):
concurrent=config.concurrent, concurrent=config.concurrent,
timeout=config.timeout, timeout=config.timeout,
logger=self.logger, logger=self.logger,
retry=config.retry retry=config.retry,
system_proxy_enable=config.system_proxy_enable,
) )
self.glossary_agent = GlossaryAgent(glossary_agent_config) self.glossary_agent = GlossaryAgent(glossary_agent_config)
@abstractmethod @abstractmethod
def translate(self, document: T) -> Document: def translate(self, document: T) -> Document: ...
...
@abstractmethod @abstractmethod
async def translate_async(self, document: T) -> Document: async def translate_async(self, document: T) -> Document: ...
...

View File

@@ -12,3 +12,6 @@ def get_httpx_proxies():
if http_proxy: if http_proxy:
proxies["http://"] = http_proxy proxies["http://"] = http_proxy
return proxies return proxies
if __name__ == '__main__':
print(get_httpx_proxies())