agent增加system_proxy_enable参数
This commit is contained in:
@@ -48,6 +48,7 @@ class AgentConfig:
|
||||
timeout: int = 1200 # 单位(秒),这个值是httpx.Timeout中read的值,并非总的超时时间
|
||||
thinking: ThinkingMode = "disable"
|
||||
retry: int = 2
|
||||
system_proxy_enable: bool = USE_PROXY
|
||||
|
||||
|
||||
class TotalErrorCounter:
|
||||
@@ -111,14 +112,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
|
||||
# 尝试从不同格式获取cached_tokens
|
||||
# 格式1: input_tokens_details.cached_tokens
|
||||
if (
|
||||
"input_tokens_details" in usage
|
||||
and "cached_tokens" in usage["input_tokens_details"]
|
||||
"input_tokens_details" in usage
|
||||
and "cached_tokens" in usage["input_tokens_details"]
|
||||
):
|
||||
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
|
||||
# 格式2: prompt_tokens_details.cached_tokens
|
||||
elif (
|
||||
"prompt_tokens_details" in usage
|
||||
and "cached_tokens" in usage["prompt_tokens_details"]
|
||||
"prompt_tokens_details" in usage
|
||||
and "cached_tokens" in usage["prompt_tokens_details"]
|
||||
):
|
||||
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||||
# 格式3: prompt_cache_hit_tokens (直接在usage下)
|
||||
@@ -128,20 +129,20 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
|
||||
# 尝试从不同格式获取reasoning_tokens
|
||||
# 格式1: output_tokens_details.reasoning_tokens
|
||||
if (
|
||||
"output_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["output_tokens_details"]
|
||||
"output_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["output_tokens_details"]
|
||||
):
|
||||
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
|
||||
# 格式2: completion_tokens_details.reasoning_tokens
|
||||
elif (
|
||||
"completion_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["completion_tokens_details"]
|
||||
"completion_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["completion_tokens_details"]
|
||||
):
|
||||
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
|
||||
return input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||
except TypeError as e:
|
||||
print(f"获取token发生错误:{e.__repr__()}")
|
||||
return -1,-1,-1,-1
|
||||
return -1, -1, -1, -1
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
@@ -155,11 +156,11 @@ class TokenCounter:
|
||||
self.logger = logger
|
||||
|
||||
def add(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cached_tokens: int,
|
||||
output_tokens: int,
|
||||
reasoning_tokens: int,
|
||||
self,
|
||||
input_tokens: int,
|
||||
cached_tokens: int,
|
||||
output_tokens: int,
|
||||
reasoning_tokens: int,
|
||||
):
|
||||
with self.lock:
|
||||
self.input_tokens += input_tokens
|
||||
@@ -248,6 +249,7 @@ class Agent:
|
||||
|
||||
self.retry = config.retry
|
||||
|
||||
self.system_proxy_enable=config.system_proxy_enable
|
||||
def _add_thinking_mode(self, data: dict):
|
||||
if self.domain not in self._think_factory:
|
||||
return
|
||||
@@ -258,7 +260,7 @@ class Agent:
|
||||
data[field_thinking] = val_disable
|
||||
|
||||
def _prepare_request_data(
|
||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
||||
):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
@@ -280,16 +282,16 @@ class Agent:
|
||||
return headers, data
|
||||
|
||||
async def send_async(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
best_partial_result: dict | None = None,
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
best_partial_result: dict | None = None,
|
||||
) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
@@ -430,13 +432,13 @@ class Agent:
|
||||
)
|
||||
|
||||
async def send_prompts_async(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
) -> list[Any]:
|
||||
max_concurrent = (
|
||||
self.max_concurrent if max_concurrent is None else max_concurrent
|
||||
@@ -447,7 +449,7 @@ class Agent:
|
||||
)
|
||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||
self.total_error_counter.max_errors_count = (
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
)
|
||||
|
||||
# 新增:在每次批量发送前重置计数器
|
||||
@@ -459,7 +461,7 @@ class Agent:
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
proxies = get_httpx_proxies() if self.system_proxy_enable else None
|
||||
|
||||
limits = httpx.Limits(
|
||||
max_connections=self.max_concurrent * 2, # 为重试和并发预留空间
|
||||
@@ -467,7 +469,7 @@ class Agent:
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
) as client:
|
||||
|
||||
async def send_with_semaphore(p_text: str):
|
||||
@@ -498,7 +500,7 @@ class Agent:
|
||||
|
||||
# 新增:打印token使用统计
|
||||
token_stats = self.token_counter.get_stats()
|
||||
if token_stats['input_tokens']<0:
|
||||
if token_stats['input_tokens'] < 0:
|
||||
self.logger.info("Token统计失败")
|
||||
else:
|
||||
self.logger.info(
|
||||
@@ -510,16 +512,16 @@ class Agent:
|
||||
return results
|
||||
|
||||
def send(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler=None,
|
||||
result_handler=None,
|
||||
error_result_handler=None,
|
||||
best_partial_result: dict | None = None,
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler=None,
|
||||
result_handler=None,
|
||||
error_result_handler=None,
|
||||
best_partial_result: dict | None = None,
|
||||
) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
@@ -656,14 +658,14 @@ class Agent:
|
||||
)
|
||||
|
||||
def _send_prompt_count(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str,
|
||||
count: PromptsCounter,
|
||||
pre_send_handler,
|
||||
result_handler,
|
||||
error_result_handler,
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str,
|
||||
count: PromptsCounter,
|
||||
pre_send_handler,
|
||||
result_handler,
|
||||
error_result_handler,
|
||||
) -> Any:
|
||||
result = self.send(
|
||||
client,
|
||||
@@ -677,12 +679,12 @@ class Agent:
|
||||
return result
|
||||
|
||||
def send_prompts(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
) -> list[Any]:
|
||||
self.logger.info(
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
|
||||
@@ -691,7 +693,7 @@ class Agent:
|
||||
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
|
||||
)
|
||||
self.total_error_counter.max_errors_count = (
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
)
|
||||
|
||||
# 新增:在每次批量发送前重置计数器
|
||||
@@ -710,9 +712,9 @@ class Agent:
|
||||
max_connections=self.max_concurrent * 2, # 允许连接复用
|
||||
max_keepalive_connections=self.max_concurrent, # 保持活跃连接
|
||||
)
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
proxies = get_httpx_proxies() if self.system_proxy_enable else None
|
||||
with httpx.Client(
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
) as client:
|
||||
clients = itertools.repeat(client, len(prompts))
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
|
||||
@@ -38,7 +38,8 @@ class AssTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -53,7 +53,8 @@ class DocxTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -44,7 +44,8 @@ class EpubTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -101,7 +101,8 @@ class HtmlTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -34,7 +34,8 @@ class JsonTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.json_paths = config.json_paths
|
||||
|
||||
@@ -34,7 +34,8 @@ class MDTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry)
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable)
|
||||
self.translate_agent = MDTranslateAgent(agent_config)
|
||||
|
||||
def translate(self, document: MarkdownDocument) -> Self:
|
||||
|
||||
@@ -40,7 +40,8 @@ class SrtTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -58,7 +58,8 @@ class TXTTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
@@ -42,7 +42,8 @@ class XlsxTranslator(AiTranslator):
|
||||
timeout=config.timeout,
|
||||
logger=self.logger,
|
||||
glossary_dict=config.glossary_dict,
|
||||
retry=config.retry
|
||||
retry=config.retry,
|
||||
system_proxy_enable=config.system_proxy_enable
|
||||
)
|
||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||
self.insert_mode = config.insert_mode
|
||||
|
||||
Reference in New Issue
Block a user