到达上限时也应该增加未解决错误。

This commit is contained in:
xunbu
2025-09-13 11:02:18 +08:00
parent 66606f077c
commit 29fb97d338

View File

@@ -111,14 +111,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取cached_tokens
# 格式1: input_tokens_details.cached_tokens
if (
"input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"]
"input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"]
):
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
# 格式2: prompt_tokens_details.cached_tokens
elif (
"prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"]
"prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"]
):
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
# 格式3: prompt_cache_hit_tokens (直接在usage下)
@@ -128,14 +128,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取reasoning_tokens
# 格式1: output_tokens_details.reasoning_tokens
if (
"output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"]
"output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"]
):
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
# 格式2: completion_tokens_details.reasoning_tokens
elif (
"completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"]
"completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"]
):
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
@@ -153,11 +153,11 @@ class TokenCounter:
self.logger = logger
def add(
self,
input_tokens: int,
cached_tokens: int,
output_tokens: int,
reasoning_tokens: int,
self,
input_tokens: int,
cached_tokens: int,
output_tokens: int,
reasoning_tokens: int,
):
with self.lock:
self.input_tokens += input_tokens
@@ -256,7 +256,7 @@ class Agent:
data[field_thinking] = val_disable
def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
):
if temperature is None:
temperature = self.temperature
@@ -278,16 +278,16 @@ class Agent:
return headers, data
async def send_async(
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None,
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
@@ -325,9 +325,7 @@ class Agent:
)
if retry_count > 0:
self.logger.info(
f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。"
)
self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。")
# print(f"result:=============================================================\n{result}\n================\n")
return (
@@ -372,6 +370,9 @@ class Agent:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
# 新增:当因为达到错误上限而不再重试时,增加未解决错误计数
with self.unresolved_error_lock:
self.unresolved_error_count += 1
return (
best_partial_result
if best_partial_result
@@ -383,6 +384,9 @@ class Agent:
)
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
# 新增:当因为达到错误上限而不再重试时,增加未解决错误计数
with self.unresolved_error_lock:
self.unresolved_error_count += 1
return (
best_partial_result
if best_partial_result
@@ -424,13 +428,13 @@ class Agent:
)
async def send_prompts_async(
self,
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
self,
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
max_concurrent = (
self.max_concurrent if max_concurrent is None else max_concurrent
@@ -441,7 +445,7 @@ class Agent:
)
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
@@ -461,8 +465,9 @@ class Agent:
)
async with httpx.AsyncClient(
trust_env=False, proxies=proxies, verify=False, limits=limits
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
async def send_with_semaphore(p_text: str):
async with semaphore:
result = await self.send_async(
@@ -500,16 +505,16 @@ class Agent:
return results
def send(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
best_partial_result: dict | None = None,
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
@@ -546,9 +551,7 @@ class Agent:
)
if retry_count > 0:
self.logger.info(
f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。"
)
self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。")
return (
result
@@ -590,6 +593,9 @@ class Agent:
if retry_count == 0:
if self.total_error_counter.add():
self.logger.error("错误次数过多,已达到上限,不再重试。")
# 新增:当因为达到错误上限而不再重试时,增加未解决错误计数
with self.unresolved_error_lock:
self.unresolved_error_count += 1
return (
best_partial_result
if best_partial_result
@@ -601,6 +607,9 @@ class Agent:
)
elif self.total_error_counter.reach_limit():
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
# 新增:当因为达到错误上限而不再重试时,增加未解决错误计数
with self.unresolved_error_lock:
self.unresolved_error_count += 1
return (
best_partial_result
if best_partial_result
@@ -642,14 +651,14 @@ class Agent:
)
def _send_prompt_count(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler,
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler,
) -> Any:
result = self.send(
client,
@@ -663,12 +672,12 @@ class Agent:
return result
def send_prompts(
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
@@ -677,7 +686,7 @@ class Agent:
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
)
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
@@ -698,7 +707,7 @@ class Agent:
)
proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client(
trust_env=False, proxies=proxies, verify=False, limits=limits
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: