添加token计数,修正阿里云思考模式
This commit is contained in:
@@ -82,6 +82,113 @@ class PromptsCounter:
|
|||||||
self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
|
self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
    """
    Extract token usage figures from an API response.

    Several ``usage`` layouts are recognised:

    1. ``usage.input_tokens`` / ``usage.output_tokens`` with
       ``usage.input_tokens_details.cached_tokens`` and
       ``usage.output_tokens_details.reasoning_tokens``
    2. ``usage.prompt_tokens`` / ``usage.completion_tokens`` with
       ``usage.prompt_tokens_details.cached_tokens``
    3. ``usage.prompt_cache_hit_tokens`` and
       ``usage.completion_tokens_details.reasoning_tokens``

    Missing, ``null`` or non-integer fields count as 0, so a sparse or
    partially populated ``usage`` object never raises.

    Args:
        response_data: Decoded JSON body of the API response.

    Returns:
        tuple: (input_tokens, cached_tokens, output_tokens, reasoning_tokens)
    """
    usage = response_data.get("usage") if isinstance(response_data, dict) else None
    if not isinstance(usage, dict):
        # No usage section at all (or an explicit null): nothing to count.
        return 0, 0, 0, 0

    # Candidate paths are tried in order; the first one that resolves to an
    # integer wins. "prompt_tokens"/"completion_tokens" stay first so responses
    # that already worked keep producing the same numbers.
    input_tokens = _first_token_count(
        usage,
        (("prompt_tokens",), ("input_tokens",)),
    )
    output_tokens = _first_token_count(
        usage,
        (("completion_tokens",), ("output_tokens",)),
    )
    cached_tokens = _first_token_count(
        usage,
        (
            ("input_tokens_details", "cached_tokens"),
            ("prompt_tokens_details", "cached_tokens"),
            ("prompt_cache_hit_tokens",),
        ),
    )
    reasoning_tokens = _first_token_count(
        usage,
        (
            ("output_tokens_details", "reasoning_tokens"),
            ("completion_tokens_details", "reasoning_tokens"),
        ),
    )

    return input_tokens, cached_tokens, output_tokens, reasoning_tokens


def _first_token_count(usage: dict, paths: tuple) -> int:
    """Return the first integer found by following any of ``paths`` into ``usage``, else 0."""
    for path in paths:
        node = usage
        for key in path:
            # A null or non-dict intermediate value ends this path quietly.
            node = node.get(key) if isinstance(node, dict) else None
        # bool is a subclass of int; exclude it so True/False is never a count.
        if isinstance(node, int) and not isinstance(node, bool):
            return node
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCounter:
    """
    Running totals of token usage, with every read and write guarded by ``self.lock``.

    Attributes are plain integers: ``input_tokens``, ``cached_tokens``,
    ``output_tokens``, ``reasoning_tokens`` and ``total_tokens`` (the sum of
    input and output tokens added so far).
    """

    def __init__(self, logger: logging.Logger):
        """
        Create a counter with all totals at zero.

        Args:
            logger: Stored on the instance as ``self.logger``; this class does
                not emit any log records itself.
        """
        self.lock = Lock()
        self.logger = logger
        # Single source of truth for the zero state.
        self.reset()

    def add(
        self,
        input_tokens: int,
        cached_tokens: int,
        output_tokens: int,
        reasoning_tokens: int,
    ):
        """
        Add the usage of one request to the running totals.

        ``None`` is accepted for any argument and counted as 0, so a response
        with a null counter cannot abort the caller with a ``TypeError``.

        Args:
            input_tokens: Input tokens of the request.
            cached_tokens: Cached tokens of the request (reported separately,
                not added to ``total_tokens``).
            output_tokens: Output tokens of the request.
            reasoning_tokens: Reasoning tokens of the request (reported
                separately, not added to ``total_tokens``).
        """
        input_tokens = input_tokens or 0
        cached_tokens = cached_tokens or 0
        output_tokens = output_tokens or 0
        reasoning_tokens = reasoning_tokens or 0

        with self.lock:
            self.input_tokens += input_tokens
            self.cached_tokens += cached_tokens
            self.output_tokens += output_tokens
            self.reasoning_tokens += reasoning_tokens
            self.total_tokens += input_tokens + output_tokens

    def get_stats(self):
        """
        Return a snapshot of the current totals.

        Returns:
            dict: keys ``input_tokens``, ``cached_tokens``, ``output_tokens``,
            ``reasoning_tokens`` and ``total_tokens``.
        """
        with self.lock:
            return {
                "input_tokens": self.input_tokens,
                "cached_tokens": self.cached_tokens,
                "output_tokens": self.output_tokens,
                "reasoning_tokens": self.reasoning_tokens,
                "total_tokens": self.total_tokens,
            }

    def reset(self):
        """Set every total back to zero."""
        with self.lock:
            self.input_tokens = 0
            self.cached_tokens = 0
            self.output_tokens = 0
            self.reasoning_tokens = 0
            self.total_tokens = 0
|
||||||
|
|
||||||
|
|
||||||
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||||||
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
|
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
|
||||||
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
||||||
@@ -90,22 +197,30 @@ ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
|||||||
class Agent:
|
class Agent:
|
||||||
_think_factory = {
|
_think_factory = {
|
||||||
"open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
"open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
||||||
"dashscope.aliyuncs.com": ("enable_thinking ", True, False),
|
"dashscope.aliyuncs.com": (
|
||||||
"ark.cn-beijing.volces.com": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
|
"extra_body",
|
||||||
"generativelanguage.googleapis.com": ("extra_body",
|
{"enable_thinking": True},
|
||||||
{"google": {
|
{"enable_thinking": False},
|
||||||
"thinking_config": {
|
),
|
||||||
"thinking_budget": -1,
|
"ark.cn-beijing.volces.com": (
|
||||||
"include_thoughts": True
|
"thinking",
|
||||||
|
{"type": "enabled"},
|
||||||
|
{"type": "disabled"},
|
||||||
|
),
|
||||||
|
"generativelanguage.googleapis.com": (
|
||||||
|
"extra_body",
|
||||||
|
{
|
||||||
|
"google": {
|
||||||
|
"thinking_config": {"thinking_budget": -1, "include_thoughts": True}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"google": {
|
||||||
|
"thinking_config": {"thinking_budget": 0, "include_thoughts": False}
|
||||||
}
|
}
|
||||||
}, {"google": {
|
},
|
||||||
"thinking_config": {
|
),
|
||||||
"thinking_budget": 0,
|
"api.siliconflow.cn": ("enable_thinking", True, False),
|
||||||
"include_thoughts": False
|
|
||||||
}
|
|
||||||
}}),
|
|
||||||
"api.siliconflow.cn": ("enable_thinking", True, False)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: AgentConfig):
|
def __init__(self, config: AgentConfig):
|
||||||
@@ -126,6 +241,8 @@ class Agent:
|
|||||||
# 新增:用于统计最终未解决的错误
|
# 新增:用于统计最终未解决的错误
|
||||||
self.unresolved_error_lock = Lock()
|
self.unresolved_error_lock = Lock()
|
||||||
self.unresolved_error_count = 0
|
self.unresolved_error_count = 0
|
||||||
|
# 新增:用于统计token使用情况
|
||||||
|
self.token_counter = TokenCounter(logger=self.logger)
|
||||||
|
|
||||||
def _add_thinking_mode(self, data: dict):
|
def _add_thinking_mode(self, data: dict):
|
||||||
if self.domain not in self._think_factory:
|
if self.domain not in self._think_factory:
|
||||||
@@ -136,16 +253,20 @@ class Agent:
|
|||||||
elif self.thinking == "disable":
|
elif self.thinking == "disable":
|
||||||
data[field_thinking] = val_disable
|
data[field_thinking] = val_disable
|
||||||
|
|
||||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
def _prepare_request_data(
|
||||||
|
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
||||||
|
):
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = self.temperature
|
temperature = self.temperature
|
||||||
headers = {"Content-Type": "application/json",
|
headers = {
|
||||||
"Authorization": f"Bearer {self.key}"}
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.key}",
|
||||||
|
}
|
||||||
data = {
|
data = {
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
@@ -154,12 +275,18 @@ class Agent:
|
|||||||
self._add_thinking_mode(data)
|
self._add_thinking_mode(data)
|
||||||
return headers, data
|
return headers, data
|
||||||
|
|
||||||
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
async def send_async(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: None | str = None,
|
||||||
|
retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
best_partial_result: dict | None = None) -> Any:
|
best_partial_result: dict | None = None,
|
||||||
|
) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if pre_send_handler:
|
if pre_send_handler:
|
||||||
@@ -170,21 +297,42 @@ class Agent:
|
|||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.baseurl}/chat/completions",
|
f"{self.baseurl}/chat/completions",
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
# print(f"【测试】resp:\n{response.json()}")
|
||||||
result = response.json()["choices"][0]["message"]["content"]
|
result = response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
# 获取token使用情况
|
||||||
|
response_data = response.json()
|
||||||
|
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||||
|
extract_token_info(response_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新token计数器
|
||||||
|
self.token_counter.add(
|
||||||
|
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||||
|
)
|
||||||
|
|
||||||
if retry_count > 0:
|
if retry_count > 0:
|
||||||
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
|
self.logger.info(
|
||||||
|
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
|
||||||
|
)
|
||||||
|
|
||||||
# print(f"result:=============================================================\n{result}\n================\n")
|
# print(f"result:=============================================================\n{result}\n================\n")
|
||||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
return (
|
||||||
|
result
|
||||||
|
if result_handler is None
|
||||||
|
else result_handler(result, prompt, self.logger)
|
||||||
|
)
|
||||||
|
|
||||||
except AgentResultError as e:
|
except AgentResultError as e:
|
||||||
self.logger.error(f"AI返回结果有误: {e}")
|
self.logger.error(f"AI返回结果有误: {e}")
|
||||||
@@ -199,7 +347,9 @@ class Agent:
|
|||||||
|
|
||||||
# 捕获硬错误
|
# 捕获硬错误
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
self.logger.error(f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}")
|
self.logger.error(
|
||||||
|
f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}"
|
||||||
|
)
|
||||||
should_retry = True
|
should_retry = True
|
||||||
is_hard_error = True
|
is_hard_error = True
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
@@ -220,20 +370,40 @@ class Agent:
|
|||||||
if retry_count == 0:
|
if retry_count == 0:
|
||||||
if self.total_error_counter.add():
|
if self.total_error_counter.add():
|
||||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||||
return best_partial_result if best_partial_result else (
|
return (
|
||||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
best_partial_result
|
||||||
|
if best_partial_result
|
||||||
|
else (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
)
|
||||||
elif self.total_error_counter.reach_limit():
|
elif self.total_error_counter.reach_limit():
|
||||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||||
return best_partial_result if best_partial_result else (
|
return (
|
||||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
best_partial_result
|
||||||
|
if best_partial_result
|
||||||
|
else (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
return await self.send_async(
|
||||||
|
client,
|
||||||
|
prompt,
|
||||||
|
system_prompt,
|
||||||
|
retry=True,
|
||||||
|
retry_count=retry_count + 1,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
best_partial_result=best_partial_result)
|
best_partial_result=best_partial_result,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if should_retry:
|
if should_retry:
|
||||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||||
@@ -245,7 +415,11 @@ class Agent:
|
|||||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||||
return best_partial_result
|
return best_partial_result
|
||||||
|
|
||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
|
||||||
async def send_prompts_async(
|
async def send_prompts_async(
|
||||||
self,
|
self,
|
||||||
@@ -254,17 +428,24 @@ class Agent:
|
|||||||
max_concurrent: int | None = None,
|
max_concurrent: int | None = None,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
max_concurrent = (
|
||||||
|
self.max_concurrent if max_concurrent is None else max_concurrent
|
||||||
|
)
|
||||||
total = len(prompts)
|
total = len(prompts)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}")
|
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature}"
|
||||||
|
)
|
||||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
self.total_error_counter.max_errors_count = (
|
||||||
|
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
# 新增:在每次批量发送前重置计数器
|
# 新增:在每次批量发送前重置计数器
|
||||||
self.unresolved_error_count = 0
|
self.unresolved_error_count = 0
|
||||||
|
# 重置token计数器
|
||||||
|
self.token_counter.reset()
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
semaphore = asyncio.Semaphore(max_concurrent)
|
semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
@@ -274,10 +455,13 @@ class Agent:
|
|||||||
|
|
||||||
limits = httpx.Limits(
|
limits = httpx.Limits(
|
||||||
max_connections=self.max_concurrent * 2, # 为重试和并发预留空间
|
max_connections=self.max_concurrent * 2, # 为重试和并发预留空间
|
||||||
max_keepalive_connections=self.max_concurrent # 保持活动的连接数
|
max_keepalive_connections=self.max_concurrent, # 保持活动的连接数
|
||||||
)
|
)
|
||||||
|
|
||||||
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
|
async with httpx.AsyncClient(
|
||||||
|
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||||
|
) as client:
|
||||||
|
|
||||||
async def send_with_semaphore(p_text: str):
|
async def send_with_semaphore(p_text: str):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
result = await self.send_async(
|
result = await self.send_async(
|
||||||
@@ -300,13 +484,32 @@ class Agent:
|
|||||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
|
||||||
# 新增:在所有任务完成后打印未解决的错误总数
|
# 新增:在所有任务完成后打印未解决的错误总数
|
||||||
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
|
self.logger.info(
|
||||||
|
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 新增:打印token使用统计
|
||||||
|
token_stats = self.token_counter.get_stats()
|
||||||
|
self.logger.info(
|
||||||
|
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
|
||||||
|
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
|
||||||
|
f"总计: {token_stats['total_tokens']/1000:.2f}K"
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
def send(
|
||||||
pre_send_handler=None, result_handler=None, error_result_handler=None,
|
self,
|
||||||
best_partial_result: dict | None = None) -> Any:
|
client: httpx.Client,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: None | str = None,
|
||||||
|
retry=True,
|
||||||
|
retry_count=0,
|
||||||
|
pre_send_handler=None,
|
||||||
|
result_handler=None,
|
||||||
|
error_result_handler=None,
|
||||||
|
best_partial_result: dict | None = None,
|
||||||
|
) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if pre_send_handler:
|
if pre_send_handler:
|
||||||
@@ -316,21 +519,41 @@ class Agent:
|
|||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"{self.baseurl}/chat/completions",
|
f"{self.baseurl}/chat/completions",
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()["choices"][0]["message"]["content"]
|
result = response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
if retry_count > 0:
|
# 获取token使用情况
|
||||||
self.logger.info(f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。")
|
response_data = response.json()
|
||||||
|
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||||
|
extract_token_info(response_data)
|
||||||
|
)
|
||||||
|
|
||||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
# 更新token计数器
|
||||||
|
self.token_counter.add(
|
||||||
|
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
if retry_count > 0:
|
||||||
|
self.logger.info(
|
||||||
|
f"重试成功 (第 {retry_count}/{MAX_RETRY_COUNT} 次尝试)。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
result
|
||||||
|
if result_handler is None
|
||||||
|
else result_handler(result, prompt, self.logger)
|
||||||
|
)
|
||||||
except AgentResultError as e:
|
except AgentResultError as e:
|
||||||
self.logger.error(f"AI返回结果有误: {e}")
|
self.logger.error(f"AI返回结果有误: {e}")
|
||||||
should_retry = True
|
should_retry = True
|
||||||
@@ -343,7 +566,9 @@ class Agent:
|
|||||||
|
|
||||||
# 捕获硬错误
|
# 捕获硬错误
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
self.logger.error(f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}")
|
self.logger.error(
|
||||||
|
f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}"
|
||||||
|
)
|
||||||
should_retry = True
|
should_retry = True
|
||||||
is_hard_error = True
|
is_hard_error = True
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
@@ -364,20 +589,40 @@ class Agent:
|
|||||||
if retry_count == 0:
|
if retry_count == 0:
|
||||||
if self.total_error_counter.add():
|
if self.total_error_counter.add():
|
||||||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||||||
return best_partial_result if best_partial_result else (
|
return (
|
||||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
best_partial_result
|
||||||
|
if best_partial_result
|
||||||
|
else (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
)
|
||||||
elif self.total_error_counter.reach_limit():
|
elif self.total_error_counter.reach_limit():
|
||||||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||||||
return best_partial_result if best_partial_result else (
|
return (
|
||||||
prompt if error_result_handler is None else error_result_handler(prompt, self.logger))
|
best_partial_result
|
||||||
|
if best_partial_result
|
||||||
|
else (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
self.logger.info(f"正在重试第 {retry_count + 1}/{MAX_RETRY_COUNT} 次...")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
return self.send(
|
||||||
|
client,
|
||||||
|
prompt,
|
||||||
|
system_prompt,
|
||||||
|
retry=True,
|
||||||
|
retry_count=retry_count + 1,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
best_partial_result=best_partial_result)
|
best_partial_result=best_partial_result,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if should_retry:
|
if should_retry:
|
||||||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||||||
@@ -389,15 +634,30 @@ class Agent:
|
|||||||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||||||
return best_partial_result
|
return best_partial_result
|
||||||
|
|
||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return (
|
||||||
|
prompt
|
||||||
|
if error_result_handler is None
|
||||||
|
else error_result_handler(prompt, self.logger)
|
||||||
|
)
|
||||||
|
|
||||||
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
|
def _send_prompt_count(
|
||||||
|
self,
|
||||||
|
client: httpx.Client,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: None | str,
|
||||||
|
count: PromptsCounter,
|
||||||
pre_send_handler,
|
pre_send_handler,
|
||||||
result_handler,
|
result_handler,
|
||||||
error_result_handler) -> Any:
|
error_result_handler,
|
||||||
result = self.send(client, prompt, system_prompt, pre_send_handler=pre_send_handler,
|
) -> Any:
|
||||||
|
result = self.send(
|
||||||
|
client,
|
||||||
|
prompt,
|
||||||
|
system_prompt,
|
||||||
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler)
|
error_result_handler=error_result_handler,
|
||||||
|
)
|
||||||
count.add()
|
count.add()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -407,15 +667,22 @@ class Agent:
|
|||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}")
|
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
|
||||||
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
)
|
||||||
self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR
|
self.logger.info(
|
||||||
|
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
|
||||||
|
)
|
||||||
|
self.total_error_counter.max_errors_count = (
|
||||||
|
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
# 新增:在每次批量发送前重置计数器
|
# 新增:在每次批量发送前重置计数器
|
||||||
self.unresolved_error_count = 0
|
self.unresolved_error_count = 0
|
||||||
|
# 重置token计数器
|
||||||
|
self.token_counter.reset()
|
||||||
|
|
||||||
counter = PromptsCounter(len(prompts), self.logger)
|
counter = PromptsCounter(len(prompts), self.logger)
|
||||||
|
|
||||||
@@ -426,23 +693,41 @@ class Agent:
|
|||||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||||
limits = httpx.Limits(
|
limits = httpx.Limits(
|
||||||
max_connections=self.max_concurrent * 2, # 允许连接复用
|
max_connections=self.max_concurrent * 2, # 允许连接复用
|
||||||
max_keepalive_connections=self.max_concurrent # 保持活跃连接
|
max_keepalive_connections=self.max_concurrent, # 保持活跃连接
|
||||||
)
|
)
|
||||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||||
with httpx.Client(trust_env=False, proxies=proxies, verify=False, limits=limits) as client:
|
with httpx.Client(
|
||||||
|
trust_env=False, proxies=proxies, verify=False, limits=limits
|
||||||
|
) as client:
|
||||||
clients = itertools.repeat(client, len(prompts))
|
clients = itertools.repeat(client, len(prompts))
|
||||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||||
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
results_iterator = executor.map(
|
||||||
|
self._send_prompt_count,
|
||||||
|
clients,
|
||||||
|
prompts,
|
||||||
|
system_prompts,
|
||||||
|
counters,
|
||||||
pre_send_handlers,
|
pre_send_handlers,
|
||||||
result_handlers,
|
result_handlers,
|
||||||
error_result_handlers)
|
error_result_handlers,
|
||||||
|
)
|
||||||
output_list = list(results_iterator)
|
output_list = list(results_iterator)
|
||||||
|
|
||||||
# 新增:在所有任务完成后打印未解决的错误总数
|
# 新增:在所有任务完成后打印未解决的错误总数
|
||||||
self.logger.info(f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}")
|
self.logger.info(
|
||||||
|
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 新增:打印token使用统计
|
||||||
|
token_stats = self.token_counter.get_stats()
|
||||||
|
self.logger.info(
|
||||||
|
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
|
||||||
|
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
|
||||||
|
f"总计: {token_stats['total_tokens']/1000:.2f}K"
|
||||||
|
)
|
||||||
|
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user