添加token计数,修正阿里云思考模式
This commit is contained in:
@@ -82,6 +82,113 @@ class PromptsCounter:
|
||||
self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
|
||||
|
||||
|
||||
def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
从API响应中提取token信息
|
||||
|
||||
支持多种response格式:
|
||||
1. 格式1: usage.input_tokens_details.cached_tokens 和 usage.output_tokens_details.reasoning_tokens
|
||||
2. 格式2: usage.prompt_tokens_details.cached_tokens
|
||||
3. 格式3: usage.prompt_cache_hit_tokens 和 usage.completion_tokens_details.reasoning_tokens
|
||||
|
||||
Args:
|
||||
response_data: API响应数据
|
||||
|
||||
Returns:
|
||||
tuple: (input_tokens, cached_tokens, output_tokens, reasoning_tokens)
|
||||
"""
|
||||
if "usage" not in response_data:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
usage = response_data["usage"]
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# 初始化token详细统计
|
||||
cached_tokens = 0
|
||||
reasoning_tokens = 0
|
||||
|
||||
# 尝试从不同格式获取cached_tokens
|
||||
# 格式1: input_tokens_details.cached_tokens
|
||||
if (
|
||||
"input_tokens_details" in usage
|
||||
and "cached_tokens" in usage["input_tokens_details"]
|
||||
):
|
||||
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
|
||||
# 格式2: prompt_tokens_details.cached_tokens
|
||||
elif (
|
||||
"prompt_tokens_details" in usage
|
||||
and "cached_tokens" in usage["prompt_tokens_details"]
|
||||
):
|
||||
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||||
# 格式3: prompt_cache_hit_tokens (直接在usage下)
|
||||
elif "prompt_cache_hit_tokens" in usage:
|
||||
cached_tokens = usage["prompt_cache_hit_tokens"]
|
||||
|
||||
# 尝试从不同格式获取reasoning_tokens
|
||||
# 格式1: output_tokens_details.reasoning_tokens
|
||||
if (
|
||||
"output_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["output_tokens_details"]
|
||||
):
|
||||
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
|
||||
# 格式2: completion_tokens_details.reasoning_tokens
|
||||
elif (
|
||||
"completion_tokens_details" in usage
|
||||
and "reasoning_tokens" in usage["completion_tokens_details"]
|
||||
):
|
||||
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
|
||||
|
||||
return input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
def __init__(self, logger: logging.Logger):
|
||||
self.lock = Lock()
|
||||
self.input_tokens = 0
|
||||
self.cached_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.reasoning_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.logger = logger
|
||||
|
||||
def add(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cached_tokens: int,
|
||||
output_tokens: int,
|
||||
reasoning_tokens: int,
|
||||
):
|
||||
with self.lock:
|
||||
self.input_tokens += input_tokens
|
||||
self.cached_tokens += cached_tokens
|
||||
self.output_tokens += output_tokens
|
||||
self.reasoning_tokens += reasoning_tokens
|
||||
self.total_tokens += input_tokens + output_tokens
|
||||
# self.logger.debug(
|
||||
# f"Token使用统计 - 输入: {self.input_tokens}(含cached: {self.cached_tokens}), "
|
||||
# f"输出: {self.output_tokens}(含reasoning: {self.reasoning_tokens}), 总计: {self.total_tokens}"
|
||||
# )
|
||||
|
||||
def get_stats(self):
|
||||
with self.lock:
|
||||
return {
|
||||
"input_tokens": self.input_tokens,
|
||||
"cached_tokens": self.cached_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"reasoning_tokens": self.reasoning_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
with self.lock:
|
||||
self.input_tokens = 0
|
||||
self.cached_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.reasoning_tokens = 0
|
||||
self.total_tokens = 0
|
||||
|
||||
|
||||
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||||
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
|
||||
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
||||
@@ -90,22 +197,30 @@ ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
||||
class Agent:
|
||||
_think_factory = {
|
||||
"open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
||||
"dashscope.aliyuncs.com": ("enable_thinking ", True, False),
|
||||
"ark.cn-beijing.volces.com": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
||||
"generativelanguage.googleapis.com": ("extra_body",
|
||||
{"google": {
|
||||
"thinking_config": {
|
||||
"thinking_budget": -1,
|
||||
"include_thoughts": True
|
||||
"dashscope.aliyuncs.com": (
|
||||
"extra_body",
|
||||
{"enable_thinking": True},
|
||||
{"enable_thinking": False},
|
||||
),
|
||||
"ark.cn-beijing.volces.com": (
|
||||
"thinking",
|
||||
{"type": "enabled"},
|
||||
{"type": "disabled"},
|
||||
),
|
||||
"generativelanguage.googleapis.com": (
|
||||
"extra_body",
|
||||
{
|
||||
"google": {
|
||||
"thinking_config": {"thinking_budget": -1, "include_thoughts": True}
|
||||
}
|
||||
},
|
||||
{
|
||||
"google": {
|
||||
"thinking_config": {"thinking_budget": 0, "include_thoughts": False}
|
||||
}
|
||||
}, {"google": {
|
||||
"thinking_config": {
|
||||
"thinking_budget": 0,
|
||||
"include_thoughts": False
|
||||
}
|
||||
}}),
|
||||
"api.siliconflow.cn": ("enable_thinking", True, False)
|
||||
},
|
||||
),
|
||||
"api.siliconflow.cn": ("enable_thinking", True, False),
|
||||
}
|
||||
|
||||
def __init__(self, config: AgentConfig):
|
||||
@@ -126,6 +241,8 @@ class Agent:
|
||||
# 新增:用于统计最终未解决的错误
|
||||
self.unresolved_error_lock = Lock()
|
||||
self.unresolved_error_count = 0
|
||||
# 新增:用于统计token使用情况
|
||||
self.token_counter = TokenCounter(logger=self.logger)
|
||||
|
||||
def _add_thinking_mode(self, data: dict):
|
||||
if self.domain not in self._think_factory:
|
||||
@@ -136,16 +253,20 @@ class Agent:
|
||||
elif self.thinking == "disable":
|
||||
data[field_thinking] = val_disable
|
||||
|
||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||
def _prepare_request_data(
|
||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
||||
):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
headers = {"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.key}"}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.key}",
|
||||
}
|
||||
data = {
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
@@ -154,12 +275,18 @@ class Agent:
|
||||
self._add_thinking_mode(data)
|
||||
return headers, data
|
||||
|
||||
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
||||
async def send_async(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
best_partial_result: dict | None = None) -> Any:
|
||||
best_partial_result: dict | None = None,
|
||||
) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
@@ -170,21 +297,42 @@ class Agent:
|
||||
should_retry = False
|
||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||
current_partial_result = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
# print(f"【测试】resp:\n{response.json()}")
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
# 获取token使用情况
|
||||
response_data = response.json()
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
)
|
||||
|
||||
# 更新token计数器
|
||||
self.token_counter.add(
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||
)
|
||||
|
||||
if retry_count > 0:
|
||||
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
|
||||
self.logger.info(
|
||||
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
|
||||
)
|
||||
|
||||
# print(f"result:=============================================================\n{result}\n================\n")
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
return (
|
||||
result
|
||||
if result_handler is None
|
||||
else result_handler(result, prompt, self.logger)
|
||||
)
|
||||
|
||||
except AgentResultError as e:
|
||||
self.logger.error(f"AI返回结果有误: {e}")
|
||||
@@ -199,7 +347,9 @@ class Agent:
|
||||
|
||||
# 捕获硬错误
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
self.logger.error(
|
||||
f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
except httpx.RequestError as e:
|
||||
@@ -220,20 +370,40 @@ class Agent:
|
||||
if retry_count == 0:
|
||||
if self.total_error_counter.add():
|
||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
return (
|
||||
best_partial_result
|
||||
if best_partial_result
|
||||
else (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
)
|
||||
elif self.total_error_counter.reach_limit():
|
||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
return (
|
||||
best_partial_result
|
||||
if best_partial_result
|
||||
else (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
)
|
||||
|
||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||
await asyncio.sleep(0.5)
|
||||
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
return await self.send_async(
|
||||
client,
|
||||
prompt,
|
||||
system_prompt,
|
||||
retry=True,
|
||||
retry_count=retry_count + 1,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
best_partial_result=best_partial_result)
|
||||
best_partial_result=best_partial_result,
|
||||
)
|
||||
else:
|
||||
if should_retry:
|
||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||
@@ -245,7 +415,11 @@ class Agent:
|
||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||
return best_partial_result
|
||||
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
return (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
|
||||
async def send_prompts_async(
|
||||
self,
|
||||
@@ -254,17 +428,24 @@ class Agent:
|
||||
max_concurrent: int | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
) -> list[Any]:
|
||||
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
||||
max_concurrent = (
|
||||
self.max_concurrent if max_concurrent is None else max_concurrent
|
||||
)
|
||||
total = len(prompts)
|
||||
self.logger.info(
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}")
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}"
|
||||
)
|
||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
self.total_error_counter.max_errors_count = (
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
)
|
||||
|
||||
# 新增:在每次批量发送前重置计数器
|
||||
self.unresolved_error_count = 0
|
||||
# 重置token计数器
|
||||
self.token_counter.reset()
|
||||
|
||||
count = 0
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
@@ -274,10 +455,13 @@ class Agent:
|
||||
|
||||
limits = httpx.Limits(
|
||||
max_connections=self.max_concurrent * 2, # 为重试和并发预留空间
|
||||
max_keepalive_connections=self.max_concurrent # 保持活动的连接数
|
||||
max_keepalive_connections=self.max_concurrent, # 保持活动的连接数
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
|
||||
async with httpx.AsyncClient(
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
) as client:
|
||||
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore:
|
||||
result = await self.send_async(
|
||||
@@ -300,13 +484,32 @@ class Agent:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 新增:在所有任务完成后打印未解决的错误总数
|
||||
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
|
||||
self.logger.info(
|
||||
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||||
)
|
||||
|
||||
# 新增:打印token使用统计
|
||||
token_stats = self.token_counter.get_stats()
|
||||
self.logger.info(
|
||||
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
|
||||
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
|
||||
f"总计: {token_stats['total_tokens']/1000:.2f}K"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
pre_send_handler=None, result_handler=None, error_result_handler=None,
|
||||
best_partial_result: dict | None = None) -> Any:
|
||||
def send(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
pre_send_handler=None,
|
||||
result_handler=None,
|
||||
error_result_handler=None,
|
||||
best_partial_result: dict | None = None,
|
||||
) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if pre_send_handler:
|
||||
@@ -316,21 +519,41 @@ class Agent:
|
||||
should_retry = False
|
||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||
current_partial_result = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
if retry_count > 0:
|
||||
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
|
||||
# 获取token使用情况
|
||||
response_data = response.json()
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
)
|
||||
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
# 更新token计数器
|
||||
self.token_counter.add(
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||
)
|
||||
|
||||
if retry_count > 0:
|
||||
self.logger.info(
|
||||
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
|
||||
)
|
||||
|
||||
return (
|
||||
result
|
||||
if result_handler is None
|
||||
else result_handler(result, prompt, self.logger)
|
||||
)
|
||||
except AgentResultError as e:
|
||||
self.logger.error(f"AI返回结果有误: {e}")
|
||||
should_retry = True
|
||||
@@ -343,7 +566,9 @@ class Agent:
|
||||
|
||||
# 捕获硬错误
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}")
|
||||
self.logger.error(
|
||||
f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
except httpx.RequestError as e:
|
||||
@@ -364,20 +589,40 @@ class Agent:
|
||||
if retry_count == 0:
|
||||
if self.total_error_counter.add():
|
||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
return (
|
||||
best_partial_result
|
||||
if best_partial_result
|
||||
else (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
)
|
||||
elif self.total_error_counter.reach_limit():
|
||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||
return best_partial_result if best_partial_result else (
|
||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
||||
return (
|
||||
best_partial_result
|
||||
if best_partial_result
|
||||
else (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
)
|
||||
|
||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||
time.sleep(0.5)
|
||||
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
return self.send(
|
||||
client,
|
||||
prompt,
|
||||
system_prompt,
|
||||
retry=True,
|
||||
retry_count=retry_count + 1,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
best_partial_result=best_partial_result)
|
||||
best_partial_result=best_partial_result,
|
||||
)
|
||||
else:
|
||||
if should_retry:
|
||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||
@@ -389,15 +634,30 @@ class Agent:
|
||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||
return best_partial_result
|
||||
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
return (
|
||||
prompt
|
||||
if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
|
||||
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
|
||||
def _send_prompt_count(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str,
|
||||
count: PromptsCounter,
|
||||
pre_send_handler,
|
||||
result_handler,
|
||||
error_result_handler) -> Any:
|
||||
result = self.send(client, prompt, system_prompt, pre_send_handler=pre_send_handler,
|
||||
error_result_handler,
|
||||
) -> Any:
|
||||
result = self.send(
|
||||
client,
|
||||
prompt,
|
||||
system_prompt,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler)
|
||||
error_result_handler=error_result_handler,
|
||||
)
|
||||
count.add()
|
||||
return result
|
||||
|
||||
@@ -407,15 +667,22 @@ class Agent:
|
||||
system_prompt: str | None = None,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
) -> list[Any]:
|
||||
self.logger.info(
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
|
||||
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
|
||||
)
|
||||
self.logger.info(
|
||||
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
|
||||
)
|
||||
self.total_error_counter.max_errors_count = (
|
||||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||
)
|
||||
|
||||
# 新增:在每次批量发送前重置计数器
|
||||
self.unresolved_error_count = 0
|
||||
# 重置token计数器
|
||||
self.token_counter.reset()
|
||||
|
||||
counter = PromptsCounter(len(prompts), self.logger)
|
||||
|
||||
@@ -426,23 +693,41 @@ class Agent:
|
||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||
limits = httpx.Limits(
|
||||
max_connections=self.max_concurrent * 2, # 允许连接复用
|
||||
max_keepalive_connections=self.max_concurrent # 保持活跃连接
|
||||
max_keepalive_connections=self.max_concurrent, # 保持活跃连接
|
||||
)
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
with httpx.Client(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
|
||||
with httpx.Client(
|
||||
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||
) as client:
|
||||
clients = itertools.repeat(client, len(prompts))
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
||||
results_iterator = executor.map(
|
||||
self._send_prompt_count,
|
||||
clients,
|
||||
prompts,
|
||||
system_prompts,
|
||||
counters,
|
||||
pre_send_handlers,
|
||||
result_handlers,
|
||||
error_result_handlers)
|
||||
error_result_handlers,
|
||||
)
|
||||
output_list = list(results_iterator)
|
||||
|
||||
# 新增:在所有任务完成后打印未解决的错误总数
|
||||
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
|
||||
self.logger.info(
|
||||
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||||
)
|
||||
|
||||
# 新增:打印token使用统计
|
||||
token_stats = self.token_counter.get_stats()
|
||||
self.logger.info(
|
||||
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
|
||||
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
|
||||
f"总计: {token_stats['total_tokens']/1000:.2f}K"
|
||||
)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user