增加stop标记检测与续传
This commit is contained in:
@@ -359,6 +359,94 @@ class Agent:
|
||||
data["response_format"] = {"type": "json_object"}
|
||||
return headers, data
|
||||
|
||||
async def _continue_fetch_async(
    self,
    client: httpx.AsyncClient,
    prompt: str,
    system_prompt: str,
    force_json: bool,
    pre_send_handler: PreSendHandlerType,
    result_handler: ResultHandlerType,
    error_result_handler: ErrorResultHandlerType,
    retry_count: int,
    accumulated_result: str = "",
    max_continuations: int = 10,
) -> Any:
    """Request more output while the API reports ``finish_reason == "length"``.

    Each round re-sends ``prompt`` followed by a "please continue" marker and,
    once some text has been collected, the text collected so far, so the model
    can resume instead of starting over. The content of every round is appended
    to ``accumulated_result``.

    Args:
        client: Async HTTP client used to POST to ``{baseurl}/chat/completions``.
        prompt: The original user prompt; also passed to the result/error handlers.
        system_prompt: The untransformed system prompt. ``pre_send_handler`` is
            applied to this same value on every round (never to its own output).
        force_json: Forwarded to ``_prepare_request_data`` as ``json_format``.
        pre_send_handler: Optional callable ``(system_prompt, prompt)`` returning
            the pair that is actually sent.
        result_handler: Optional callable ``(text, prompt, logger)`` applied to the
            final accumulated text; when ``None`` the raw text is returned.
        error_result_handler: Optional callable ``(prompt, logger)`` used when a
            request fails and nothing has been accumulated; when ``None`` the
            prompt itself is returned.
        retry_count: Accepted for signature compatibility; not used here.
        accumulated_result: Text already obtained before this call.
        max_continuations: Upper bound on continuation requests issued by this
            call (values below 1 are treated as 1).

    Returns:
        The (optionally post-processed) accumulated text, or the error handler's
        result when a request fails before any text was accumulated.

    NOTE(review): the call visible in ``send_async`` does not pass the truncated
    first response as ``accumulated_result``; unless the caller does so, that
    first fragment is neither shown to the model nor part of the returned text.
    """

    def _finalize(text: str) -> Any:
        # Kept outside the request try/except so a failing handler is not
        # mistaken for a transport error and invoked a second time.
        if result_handler is None:
            return text
        return result_handler(text, prompt, self.logger)

    # A loop instead of self-recursion: a model that keeps answering "length"
    # can no longer recurse without bound.
    for _ in range(max(1, max_continuations)):
        self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")

        continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
        if accumulated_result:
            # Give the model what it has already produced so it can pick up
            # where it stopped rather than regenerate from the beginning.
            continue_prompt += f"\n\n[已输出内容:]\n{accumulated_result}"

        # Always transform the ORIGINAL system prompt; feeding the handler its
        # own previous output would stack the transformation once per round.
        round_system_prompt = system_prompt
        if pre_send_handler:
            round_system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)

        # Rate-limit check before sending.
        estimated_tokens = self._estimate_tokens(round_system_prompt) + self._estimate_tokens(continue_prompt)
        await self.rate_limiter.acquire_async(tokens=estimated_tokens)

        headers, data = self._prepare_request_data(continue_prompt, round_system_prompt, json_format=force_json)

        try:
            response = await client.post(
                f"{self.baseurl}/chat/completions",
                json=data,
                headers=headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            response_data = response.json()

            choice = response_data["choices"][0]
            finish_reason = choice.get("finish_reason")
            # The API may return null content; treat it as an empty fragment
            # instead of failing on ``str += None``.
            additional_result = choice["message"]["content"] or ""

            input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
                extract_token_info(response_data)
            )
            self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
        except (
            httpx.HTTPStatusError,
            httpx.RequestError,
            KeyError,
            IndexError,
            ValueError,
            TypeError,
            AttributeError,
        ) as e:
            self.logger.error(f"继续获取内容失败: {repr(e)}")
            # Prefer returning whatever was collected before the failure.
            if accumulated_result:
                return _finalize(accumulated_result)
            # Nothing collected: fall back to the error handler.
            if error_result_handler is None:
                return prompt
            return error_result_handler(prompt, self.logger)

        accumulated_result += additional_result

        if finish_reason != "length":
            break
    else:
        # Every allowed round still ended with "length"; stop and return what we have.
        self.logger.warning(
            f"已达到最大续传次数 {max(1, max_continuations)},响应可能仍不完整。prompt: {prompt[:100]}..."
        )

    return _finalize(accumulated_result)
|
||||
|
||||
async def send_async(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
@@ -398,9 +486,31 @@ class Agent:
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查 finish_reason
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
if finish_reason != "stop":
|
||||
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
|
||||
self.logger.warning(
|
||||
f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..."
|
||||
)
|
||||
|
||||
# 如果是长度限制,尝试继续获取
|
||||
if finish_reason == "length":
|
||||
return await self._continue_fetch_async(
|
||||
client=client,
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
force_json=force_json,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
)
|
||||
|
||||
result = response_data["choices"][0]["message"]["content"]
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
)
|
||||
@@ -593,6 +703,86 @@ class Agent:
|
||||
|
||||
return results
|
||||
|
||||
def _continue_fetch(
    self,
    client: httpx.Client,
    prompt: str,
    system_prompt: str,
    force_json: bool,
    pre_send_handler,
    result_handler,
    error_result_handler,
    retry_count: int,
    accumulated_result: str = "",
    max_continuations: int = 10,
) -> Any:
    """Synchronously request more output while ``finish_reason == "length"``.

    Each round re-sends ``prompt`` followed by a "please continue" marker and,
    once some text has been collected, the text collected so far, so the model
    can resume instead of starting over. The content of every round is appended
    to ``accumulated_result``.

    Args:
        client: HTTP client used to POST to ``{baseurl}/chat/completions``.
        prompt: The original user prompt; also passed to the result/error handlers.
        system_prompt: The untransformed system prompt. ``pre_send_handler`` is
            applied to this same value on every round (never to its own output).
        force_json: Forwarded to ``_prepare_request_data`` as ``json_format``.
        pre_send_handler: Optional callable ``(system_prompt, prompt)`` returning
            the pair that is actually sent.
        result_handler: Optional callable ``(text, prompt, logger)`` applied to the
            final accumulated text; when ``None`` the raw text is returned.
        error_result_handler: Optional callable ``(prompt, logger)`` used when a
            request fails and nothing has been accumulated; when ``None`` the
            prompt itself is returned.
        retry_count: Accepted for signature compatibility; not used here.
        accumulated_result: Text already obtained before this call.
        max_continuations: Upper bound on continuation requests issued by this
            call (values below 1 are treated as 1).

    Returns:
        The (optionally post-processed) accumulated text, or the error handler's
        result when a request fails before any text was accumulated.

    NOTE(review): the call visible in ``send`` does not pass the truncated first
    response as ``accumulated_result``; unless the caller does so, that first
    fragment is neither shown to the model nor part of the returned text.
    """

    def _finalize(text: str) -> Any:
        # Kept outside the request try/except so a failing handler is not
        # mistaken for a transport error and invoked a second time.
        if result_handler is None:
            return text
        return result_handler(text, prompt, self.logger)

    # A loop instead of self-recursion: a model that keeps answering "length"
    # can no longer exhaust the interpreter's recursion limit.
    for _ in range(max(1, max_continuations)):
        self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")

        continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
        if accumulated_result:
            # Give the model what it has already produced so it can pick up
            # where it stopped rather than regenerate from the beginning.
            continue_prompt += f"\n\n[已输出内容:]\n{accumulated_result}"

        # Always transform the ORIGINAL system prompt; feeding the handler its
        # own previous output would stack the transformation once per round.
        round_system_prompt = system_prompt
        if pre_send_handler:
            round_system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)

        # Rate-limit check before sending.
        estimated_tokens = self._estimate_tokens(round_system_prompt) + self._estimate_tokens(continue_prompt)
        self.rate_limiter.acquire_sync(tokens=estimated_tokens)

        headers, data = self._prepare_request_data(continue_prompt, round_system_prompt, json_format=force_json)

        try:
            response = client.post(
                f"{self.baseurl}/chat/completions",
                json=data,
                headers=headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            response_data = response.json()

            choice = response_data["choices"][0]
            finish_reason = choice.get("finish_reason")
            # The API may return null content; treat it as an empty fragment
            # instead of failing on ``str += None``.
            additional_result = choice["message"]["content"] or ""

            input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
                extract_token_info(response_data)
            )
            self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
        except (
            httpx.HTTPStatusError,
            httpx.RequestError,
            KeyError,
            IndexError,
            ValueError,
            TypeError,
            AttributeError,
        ) as e:
            self.logger.error(f"继续获取内容失败: {repr(e)}")
            # Prefer returning whatever was collected before the failure.
            if accumulated_result:
                return _finalize(accumulated_result)
            # Nothing collected: fall back to the error handler.
            if error_result_handler is None:
                return prompt
            return error_result_handler(prompt, self.logger)

        accumulated_result += additional_result

        if finish_reason != "length":
            break
    else:
        # Every allowed round still ended with "length"; stop and return what we have.
        self.logger.warning(
            f"已达到最大续传次数 {max(1, max_continuations)},响应可能仍不完整。prompt: {prompt[:100]}..."
        )

    return _finalize(accumulated_result)
|
||||
|
||||
def send(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
@@ -630,10 +820,31 @@ class Agent:
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查 finish_reason
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
if finish_reason != "stop":
|
||||
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
|
||||
self.logger.warning(
|
||||
f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..."
|
||||
)
|
||||
|
||||
# 如果是长度限制,尝试继续获取
|
||||
if finish_reason == "length":
|
||||
return self._continue_fetch(
|
||||
client=client,
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
force_json=force_json,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
)
|
||||
|
||||
result = response_data["choices"][0]["message"]["content"]
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user