添加token计数,修正阿里云思考模式

This commit is contained in:
xunbu
2025-09-08 15:59:06 +08:00
parent 345bb4f430
commit 787009dcaa

View File

@@ -82,6 +82,113 @@ class PromptsCounter:
self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
"""
从API响应中提取token信息
支持多种response格式:
1. 格式1: usage.input_tokens_details.cached_tokens 和 usage.output_tokens_details.reasoning_tokens
2. 格式2: usage.prompt_tokens_details.cached_tokens
3. 格式3: usage.prompt_cache_hit_tokens 和 usage.completion_tokens_details.reasoning_tokens
Args:
response_data: API响应数据
Returns:
tuple: (input_tokens, cached_tokens, output_tokens, reasoning_tokens)
"""
if "usage" not in response_data:
return 0, 0, 0, 0
usage = response_data["usage"]
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
# 初始化token详细统计
cached_tokens = 0
reasoning_tokens = 0
# 尝试从不同格式获取cached_tokens
# 格式1: input_tokens_details.cached_tokens
if (
"input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"]
):
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
# 格式2: prompt_tokens_details.cached_tokens
elif (
"prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"]
):
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
# 格式3: prompt_cache_hit_tokens (直接在usage下)
elif "prompt_cache_hit_tokens" in usage:
cached_tokens = usage["prompt_cache_hit_tokens"]
# 尝试从不同格式获取reasoning_tokens
# 格式1: output_tokens_details.reasoning_tokens
if (
"output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"]
):
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
# 格式2: completion_tokens_details.reasoning_tokens
elif (
"completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"]
):
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
return input_tokens, cached_tokens, output_tokens, reasoning_tokens
class TokenCounter:
def __init__(self, logger: logging.Logger):
self.lock = Lock()
self.input_tokens = 0
self.cached_tokens = 0
self.output_tokens = 0
self.reasoning_tokens = 0
self.total_tokens = 0
self.logger = logger
def add(
self,
input_tokens: int,
cached_tokens: int,
output_tokens: int,
reasoning_tokens: int,
):
with self.lock:
self.input_tokens += input_tokens
self.cached_tokens += cached_tokens
self.output_tokens += output_tokens
self.reasoning_tokens += reasoning_tokens
self.total_tokens += input_tokens + output_tokens
# self.logger.debug(
# f"Token使用统计 - 输入: {self.input_tokens}(含cached: {self.cached_tokens}), "
# f"输出: {self.output_tokens}(含reasoning: {self.reasoning_tokens}), 总计: {self.total_tokens}"
# )
def get_stats(self):
with self.lock:
return {
"input_tokens": self.input_tokens,
"cached_tokens": self.cached_tokens,
"output_tokens": self.output_tokens,
"reasoning_tokens": self.reasoning_tokens,
"total_tokens": self.total_tokens,
}
def reset(self):
with self.lock:
self.input_tokens = 0
self.cached_tokens = 0
self.output_tokens = 0
self.reasoning_tokens = 0
self.total_tokens = 0
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
@@ -90,22 +197,30 @@ ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
class Agent:
_think_factory = {
"open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
"dashscope.aliyuncs.com": ("enable_thinking ", True, False),
"ark.cn-beijing.volces.com": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
"generativelanguage.googleapis.com": ("extra_body",
{"google": {
"thinking_config": {
"thinking_budget": -1,
"include_thoughts": True
"dashscope.aliyuncs.com": (
"extra_body",
{"enable_thinking": True},
{"enable_thinking": False},
),
"ark.cn-beijing.volces.com": (
"thinking",
{"type": "enabled"},
{"type": "disabled"},
),
"generativelanguage.googleapis.com": (
"extra_body",
{
"google": {
"thinking_config": {"thinking_budget": -1, "include_thoughts": True}
}
},
{
"google": {
"thinking_config": {"thinking_budget": 0, "include_thoughts": False}
}
}, {"google": {
"thinking_config": {
"thinking_budget": 0,
"include_thoughts": False
}
}}),
"api.siliconflow.cn": ("enable_thinking", True, False)
},
),
"api.siliconflow.cn": ("enable_thinking", True, False),
}
def __init__(self, config: AgentConfig):
@@ -126,6 +241,8 @@ class Agent:
# 新增:用于统计最终未解决的错误
self.unresolved_error_lock = Lock()
self.unresolved_error_count = 0
# 新增用于统计token使用情况
self.token_counter = TokenCounter(logger=self.logger)
def _add_thinking_mode(self, data: dict):
if self.domain not in self._think_factory:
@@ -136,16 +253,20 @@ class Agent:
elif self.thinking == "disable":
data[field_thinking] = val_disable
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
):
if temperature is None:
temperature = self.temperature
headers = {"Content-Type": "application/json",
"Authorization": f"Bearer {self.key}"}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.key}",
}
data = {
"model": self.model_id,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
],
"temperature": temperature,
"top_p": top_p,
@@ -154,12 +275,18 @@ class Agent:
self._add_thinking_mode(data)
return headers, data
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
async def send_async(
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None) -> Any:
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
@@ -170,21 +297,42 @@ class Agent:
should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None
input_tokens = 0
output_tokens = 0
try:
response = await client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout
timeout=self.timeout,
)
response.raise_for_status()
# print(f"【测试】resp:\n{response.json()}")
result = response.json()["choices"][0]["message"]["content"]
# 获取token使用情况
response_data = response.json()
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data)
)
# 更新token计数器
self.token_counter.add(
input_tokens, cached_tokens, output_tokens, reasoning_tokens
)
if retry_count > 0:
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
self.logger.info(
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
)
# print(f"result:=============================================================\n{result}\n================\n")
return result if result_handler is None else result_handler(result, prompt, self.logger)
return (
result
if result_handler is None
else result_handler(result, prompt, self.logger)
)
except AgentResultError as e:
self.logger.error(f"AI返回结果有误: {e}")
@@ -199,7 +347,9 @@ class Agent:
# 捕获硬错误
except httpx.HTTPStatusError as e:
self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}")
self.logger.error(
f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}"
)
should_retry = True
is_hard_error = True
except httpx.RequestError as e:
@@ -220,20 +370,40 @@ class Agent:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
return (
best_partial_result
if best_partial_result
else (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
)
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
return (
best_partial_result
if best_partial_result
else (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
)
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
await asyncio.sleep(0.5)
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
return await self.send_async(
client,
prompt,
system_prompt,
retry=True,
retry_count=retry_count + 1,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
best_partial_result=best_partial_result)
best_partial_result=best_partial_result,
)
else:
if should_retry:
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
@@ -245,7 +415,11 @@ class Agent:
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
return best_partial_result
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
return (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
async def send_prompts_async(
self,
@@ -254,17 +428,24 @@ class Agent:
max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
max_concurrent = (
self.max_concurrent if max_concurrent is None else max_concurrent
)
total = len(prompts)
self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}")
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}"
)
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
self.unresolved_error_count = 0
# 重置token计数器
self.token_counter.reset()
count = 0
semaphore = asyncio.Semaphore(max_concurrent)
@@ -274,10 +455,13 @@ class Agent:
limits = httpx.Limits(
max_connections=self.max_concurrent * 2, # 为重试和并发预留空间
max_keepalive_connections=self.max_concurrent # 保持活动的连接数
max_keepalive_connections=self.max_concurrent, # 保持活动的连接数
)
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
async with httpx.AsyncClient(
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
async def send_with_semaphore(p_text: str):
async with semaphore:
result = await self.send_async(
@@ -300,13 +484,32 @@ class Agent:
results = await asyncio.gather(*tasks, return_exceptions=False)
# 新增:在所有任务完成后打印未解决的错误总数
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
self.logger.info(
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
)
# 新增打印token使用统计
token_stats = self.token_counter.get_stats()
self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K"
)
return results
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
pre_send_handler=None, result_handler=None, error_result_handler=None,
best_partial_result: dict | None = None) -> Any:
def send(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if pre_send_handler:
@@ -316,21 +519,41 @@ class Agent:
should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None
input_tokens = 0
output_tokens = 0
try:
response = client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout
timeout=self.timeout,
)
response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
if retry_count > 0:
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
# 获取token使用情况
response_data = response.json()
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data)
)
return result if result_handler is None else result_handler(result, prompt, self.logger)
# 更新token计数器
self.token_counter.add(
input_tokens, cached_tokens, output_tokens, reasoning_tokens
)
if retry_count > 0:
self.logger.info(
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
)
return (
result
if result_handler is None
else result_handler(result, prompt, self.logger)
)
except AgentResultError as e:
self.logger.error(f"AI返回结果有误: {e}")
should_retry = True
@@ -343,7 +566,9 @@ class Agent:
# 捕获硬错误
except httpx.HTTPStatusError as e:
self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}")
self.logger.error(
f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}"
)
should_retry = True
is_hard_error = True
except httpx.RequestError as e:
@@ -364,20 +589,40 @@ class Agent:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
return (
best_partial_result
if best_partial_result
else (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
)
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
return best_partial_result if best_partial_result else (
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
return (
best_partial_result
if best_partial_result
else (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
)
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
time.sleep(0.5)
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
return self.send(
client,
prompt,
system_prompt,
retry=True,
retry_count=retry_count + 1,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
best_partial_result=best_partial_result)
best_partial_result=best_partial_result,
)
else:
if should_retry:
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
@@ -389,15 +634,30 @@ class Agent:
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
return best_partial_result
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
return (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
def _send_prompt_count(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler) -> Any:
result = self.send(client, prompt, system_prompt, pre_send_handler=pre_send_handler,
error_result_handler,
) -> Any:
result = self.send(
client,
prompt,
system_prompt,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler)
error_result_handler=error_result_handler,
)
count.add()
return result
@@ -407,15 +667,22 @@ class Agent:
system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
)
self.logger.info(
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
)
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
self.unresolved_error_count = 0
# 重置token计数器
self.token_counter.reset()
counter = PromptsCounter(len(prompts), self.logger)
@@ -426,23 +693,41 @@ class Agent:
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
limits = httpx.Limits(
max_connections=self.max_concurrent * 2, # 允许连接复用
max_keepalive_connections=self.max_concurrent # 保持活跃连接
max_keepalive_connections=self.max_concurrent, # 保持活跃连接
)
proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
with httpx.Client(
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
results_iterator = executor.map(
self._send_prompt_count,
clients,
prompts,
system_prompts,
counters,
pre_send_handlers,
result_handlers,
error_result_handlers)
error_result_handlers,
)
output_list = list(results_iterator)
# 新增:在所有任务完成后打印未解决的错误总数
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
self.logger.info(
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
)
# 新增打印token使用统计
token_stats = self.token_counter.get_stats()
self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K"
)
return output_list
if __name__ == '__main__':
if __name__ == "__main__":
pass