增加stop标记检测与续传

This commit is contained in:
xunbu
2026-01-05 23:48:40 +08:00
parent a93ab74ce1
commit ea24f7db31

View File

@@ -359,6 +359,94 @@ class Agent:
data["response_format"] = {"type": "json_object"} data["response_format"] = {"type": "json_object"}
return headers, data return headers, data
async def _continue_fetch_async(
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: str,
force_json: bool,
pre_send_handler: PreSendHandlerType,
result_handler: ResultHandlerType,
error_result_handler: ErrorResultHandlerType,
retry_count: int,
accumulated_result: str = "",
) -> Any:
"""
当 finish_reason 为 length 时,继续获取剩余内容
"""
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")
# 使用空内容继续请求,实际上多数 API 需要用户提供已获取的内容作为上下文
# 这里我们发送一个继续信号,让模型继续输出
continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
if pre_send_handler:
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
# 速率限制检查
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt)
await self.rate_limiter.acquire_async(tokens=estimated_tokens)
headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json)
try:
response = await client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout,
)
response.raise_for_status()
response_data = response.json()
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
additional_result = response_data["choices"][0]["message"]["content"]
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data)
)
self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
# 累加结果
accumulated_result += additional_result
# 如果仍然是 length继续获取
if finish_reason == "length":
return await self._continue_fetch_async(
client=client,
prompt=prompt,
system_prompt=system_prompt,
force_json=force_json,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
retry_count=retry_count,
accumulated_result=accumulated_result,
)
# 非 length 结束,返回累加结果
return (
accumulated_result
if result_handler is None
else result_handler(accumulated_result, prompt, self.logger)
)
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
self.logger.error(f"继续获取内容失败: {repr(e)}")
# 返回已获取的部分结果
if accumulated_result:
return (
accumulated_result
if result_handler is None
else result_handler(accumulated_result, prompt, self.logger)
)
# 如果没有部分结果,调用错误处理器
return (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
async def send_async( async def send_async(
self, self,
client: httpx.AsyncClient, client: httpx.AsyncClient,
@@ -398,9 +486,31 @@ class Agent:
timeout=self.timeout, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
response_data = response.json() response_data = response.json()
# 检查 finish_reason
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
if finish_reason != "stop":
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
self.logger.warning(
f"finish_reason 为 '{finish_reason}'非正常结束。prompt: {prompt[:100]}..."
)
# 如果是长度限制,尝试继续获取
if finish_reason == "length":
return await self._continue_fetch_async(
client=client,
prompt=prompt,
system_prompt=system_prompt,
force_json=force_json,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
retry_count=retry_count,
)
result = response_data["choices"][0]["message"]["content"]
input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data) extract_token_info(response_data)
) )
@@ -593,6 +703,86 @@ class Agent:
return results return results
def _continue_fetch(
self,
client: httpx.Client,
prompt: str,
system_prompt: str,
force_json: bool,
pre_send_handler,
result_handler,
error_result_handler,
retry_count: int,
accumulated_result: str = "",
) -> Any:
"""
当 finish_reason 为 length 时,继续获取剩余内容(同步版本)
"""
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")
continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
if pre_send_handler:
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt)
self.rate_limiter.acquire_sync(tokens=estimated_tokens)
headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json)
try:
response = client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout,
)
response.raise_for_status()
response_data = response.json()
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
additional_result = response_data["choices"][0]["message"]["content"]
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data)
)
self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
accumulated_result += additional_result
if finish_reason == "length":
return self._continue_fetch(
client=client,
prompt=prompt,
system_prompt=system_prompt,
force_json=force_json,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
retry_count=retry_count,
accumulated_result=accumulated_result,
)
return (
accumulated_result
if result_handler is None
else result_handler(accumulated_result, prompt, self.logger)
)
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
self.logger.error(f"继续获取内容失败: {repr(e)}")
if accumulated_result:
return (
accumulated_result
if result_handler is None
else result_handler(accumulated_result, prompt, self.logger)
)
return (
prompt
if error_result_handler is None
else error_result_handler(prompt, self.logger)
)
def send( def send(
self, self,
client: httpx.Client, client: httpx.Client,
@@ -630,10 +820,31 @@ class Agent:
timeout=self.timeout, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
response_data = response.json() response_data = response.json()
# 检查 finish_reason
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
if finish_reason != "stop":
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
self.logger.warning(
f"finish_reason 为 '{finish_reason}'非正常结束。prompt: {prompt[:100]}..."
)
# 如果是长度限制,尝试继续获取
if finish_reason == "length":
return self._continue_fetch(
client=client,
prompt=prompt,
system_prompt=system_prompt,
force_json=force_json,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
retry_count=retry_count,
)
result = response_data["choices"][0]["message"]["content"]
input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
extract_token_info(response_data) extract_token_info(response_data)
) )