修复可能的ai无限回复时出现的问题
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
@@ -21,6 +22,7 @@ from docutranslate.logger import global_logger
|
||||
from docutranslate.utils.utils import get_httpx_proxies
|
||||
|
||||
MAX_REQUESTS_PER_ERROR = 15
|
||||
MAX_CONTINUE_FETCHES = 2 # 响应被截断时,最多继续获取的次数
|
||||
|
||||
ThinkingMode = Literal["enable", "disable", "default"]
|
||||
|
||||
@@ -370,15 +372,27 @@ class Agent:
|
||||
error_result_handler: ErrorResultHandlerType,
|
||||
retry_count: int,
|
||||
accumulated_result: str = "",
|
||||
continue_count: int = 0,
|
||||
) -> Any:
|
||||
"""
|
||||
当 finish_reason 为 length 时,继续获取剩余内容
|
||||
当 finish_reason 为 length 时,继续获取剩余内容。
|
||||
注意:很多 API 并不支持这种"继续获取"模式,可能直接返回 stop 或不返回 length。
|
||||
本方法具有退化机制:如果 API 不支持继续获取,会返回已累计的结果。
|
||||
最多继续获取 MAX_CONTINUE_FETCHES 次,防止无限循环。
|
||||
"""
|
||||
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")
|
||||
if continue_count >= MAX_CONTINUE_FETCHES:
|
||||
self.logger.warning(f"已达到最大继续获取次数 ({MAX_CONTINUE_FETCHES}),返回已累计结果 ({len(accumulated_result)} 字符)")
|
||||
return (
|
||||
accumulated_result
|
||||
if result_handler is None
|
||||
else result_handler(accumulated_result, prompt, self.logger)
|
||||
)
|
||||
|
||||
# 使用空内容继续请求,实际上多数 API 需要用户提供已获取的内容作为上下文
|
||||
# 这里我们发送一个继续信号,让模型继续输出
|
||||
continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
|
||||
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符, 第 {continue_count + 1}/{MAX_CONTINUE_FETCHES} 次)...")
|
||||
|
||||
# 构造继续请求的提示
|
||||
# 关键:告知模型我们已经获取了部分内容,请继续完成
|
||||
continue_prompt = f"{prompt}\n\n[系统提示:请继续完成之前的响应。之前已输出内容为:\n---\n{accumulated_result}\n---\n请从中断处继续输出剩余内容。]"
|
||||
|
||||
if pre_send_handler:
|
||||
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
|
||||
@@ -399,8 +413,16 @@ class Agent:
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
additional_result = response_data["choices"][0]["message"]["content"]
|
||||
# 安全提取 choices 和 content
|
||||
choices = response_data.get("choices", [])
|
||||
if not choices:
|
||||
self.logger.error(f"API响应中未找到 choices 字段")
|
||||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||||
|
||||
choice = choices[0]
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
message = choice.get("message", {})
|
||||
additional_result = message.get("content", "")
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
@@ -410,7 +432,7 @@ class Agent:
|
||||
# 累加结果
|
||||
accumulated_result += additional_result
|
||||
|
||||
# 如果仍然是 length,继续获取
|
||||
# 如果仍然是 length,继续获取(限制最大轮数防止无限循环)
|
||||
if finish_reason == "length":
|
||||
return await self._continue_fetch_async(
|
||||
client=client,
|
||||
@@ -422,6 +444,7 @@ class Agent:
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
accumulated_result=accumulated_result,
|
||||
continue_count=continue_count + 1,
|
||||
)
|
||||
|
||||
# 非 length 结束,返回累加结果
|
||||
@@ -433,8 +456,9 @@ class Agent:
|
||||
|
||||
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
|
||||
self.logger.error(f"继续获取内容失败: {repr(e)}")
|
||||
# 返回已获取的部分结果
|
||||
# 退化:返回已获取的部分结果,而不是报错
|
||||
if accumulated_result:
|
||||
self.logger.warning(f"API不支持继续获取,返回已获取的部分结果 ({len(accumulated_result)} 字符)")
|
||||
return (
|
||||
accumulated_result
|
||||
if result_handler is None
|
||||
@@ -489,15 +513,21 @@ class Agent:
|
||||
response_data = response.json()
|
||||
|
||||
# 检查 finish_reason
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
if finish_reason != "stop":
|
||||
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
|
||||
self.logger.warning(
|
||||
f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..."
|
||||
)
|
||||
choices = response_data.get("choices", [])
|
||||
if not choices:
|
||||
self.logger.error(f"API响应中未找到 choices 字段")
|
||||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||||
|
||||
# 如果是长度限制,尝试继续获取
|
||||
if finish_reason == "length":
|
||||
finish_reason = choices[0].get("finish_reason", None)
|
||||
result = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# 处理不同的 finish_reason
|
||||
if finish_reason == "stop":
|
||||
# 正常结束
|
||||
pass
|
||||
elif finish_reason == "length":
|
||||
# 长度限制,尝试继续获取
|
||||
self.logger.warning(f"响应因长度限制被截断,尝试继续获取...")
|
||||
return await self._continue_fetch_async(
|
||||
client=client,
|
||||
prompt=prompt,
|
||||
@@ -507,9 +537,25 @@ class Agent:
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
accumulated_result=result,
|
||||
)
|
||||
|
||||
result = response_data["choices"][0]["message"]["content"]
|
||||
elif finish_reason in ("tool_calls", "function_call"):
|
||||
# 工具调用场景,当前代码可能不支持,直接返回已获取结果
|
||||
self.logger.warning(f"finish_reason 为 '{finish_reason}',当前不支持工具调用,返回已获取内容")
|
||||
return result if result else (
|
||||
prompt if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
elif finish_reason == "content_filter":
|
||||
# 内容被过滤
|
||||
self.logger.error(f"响应内容被过滤")
|
||||
raise ValueError("内容被过滤")
|
||||
elif finish_reason is None:
|
||||
# 某些 API 可能不返回 finish_reason,将其视为正常结束
|
||||
self.logger.warning(f"API未返回 finish_reason,视为正常结束")
|
||||
else:
|
||||
# 其他未知的 finish_reason,记录警告并返回结果
|
||||
self.logger.warning(f"未知的 finish_reason: '{finish_reason}',返回已获取内容")
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
@@ -552,7 +598,7 @@ class Agent:
|
||||
self.logger.error(f"AI请求连接错误 (async): {repr(e)}")
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
except (KeyError, IndexError, ValueError, json.JSONDecodeError) as e:
|
||||
self.logger.error(f"AI响应格式或值错误 (async), 将尝试重试: {repr(e)}")
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
@@ -714,13 +760,26 @@ class Agent:
|
||||
error_result_handler,
|
||||
retry_count: int,
|
||||
accumulated_result: str = "",
|
||||
continue_count: int = 0,
|
||||
) -> Any:
|
||||
"""
|
||||
当 finish_reason 为 length 时,继续获取剩余内容(同步版本)
|
||||
当 finish_reason 为 length 时,继续获取剩余内容(同步版本)。
|
||||
注意:很多 API 并不支持这种"继续获取"模式,可能直接返回 stop 或不返回 length。
|
||||
本方法具有退化机制:如果 API 不支持继续获取,会返回已累计的结果。
|
||||
最多继续获取 MAX_CONTINUE_FETCHES 次,防止无限循环。
|
||||
"""
|
||||
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...")
|
||||
if continue_count >= MAX_CONTINUE_FETCHES:
|
||||
self.logger.warning(f"已达到最大继续获取次数 ({MAX_CONTINUE_FETCHES}),返回已累计结果 ({len(accumulated_result)} 字符)")
|
||||
return (
|
||||
accumulated_result
|
||||
if result_handler is None
|
||||
else result_handler(accumulated_result, prompt, self.logger)
|
||||
)
|
||||
|
||||
continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]"
|
||||
self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符, 第 {continue_count + 1}/{MAX_CONTINUE_FETCHES} 次)...")
|
||||
|
||||
# 构造继续请求的提示
|
||||
continue_prompt = f"{prompt}\n\n[系统提示:请继续完成之前的响应。之前已输出内容为:\n---\n{accumulated_result}\n---\n请从中断处继续输出剩余内容。]"
|
||||
|
||||
if pre_send_handler:
|
||||
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
|
||||
@@ -740,8 +799,16 @@ class Agent:
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
additional_result = response_data["choices"][0]["message"]["content"]
|
||||
# 安全提取 choices 和 content
|
||||
choices = response_data.get("choices", [])
|
||||
if not choices:
|
||||
self.logger.error(f"API响应中未找到 choices 字段")
|
||||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||||
|
||||
choice = choices[0]
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
message = choice.get("message", {})
|
||||
additional_result = message.get("content", "")
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
@@ -761,6 +828,7 @@ class Agent:
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
accumulated_result=accumulated_result,
|
||||
continue_count=continue_count + 1,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -771,7 +839,9 @@ class Agent:
|
||||
|
||||
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
|
||||
self.logger.error(f"继续获取内容失败: {repr(e)}")
|
||||
# 退化:返回已获取的部分结果,而不是报错
|
||||
if accumulated_result:
|
||||
self.logger.warning(f"API不支持继续获取,返回已获取的部分结果 ({len(accumulated_result)} 字符)")
|
||||
return (
|
||||
accumulated_result
|
||||
if result_handler is None
|
||||
@@ -809,8 +879,6 @@ class Agent:
|
||||
should_retry = False
|
||||
is_hard_error = False
|
||||
current_partial_result = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
@@ -823,15 +891,21 @@ class Agent:
|
||||
response_data = response.json()
|
||||
|
||||
# 检查 finish_reason
|
||||
finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None)
|
||||
if finish_reason != "stop":
|
||||
# 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等
|
||||
self.logger.warning(
|
||||
f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..."
|
||||
)
|
||||
choices = response_data.get("choices", [])
|
||||
if not choices:
|
||||
self.logger.error(f"API响应中未找到 choices 字段")
|
||||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||||
|
||||
# 如果是长度限制,尝试继续获取
|
||||
if finish_reason == "length":
|
||||
finish_reason = choices[0].get("finish_reason", None)
|
||||
result = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# 处理不同的 finish_reason
|
||||
if finish_reason == "stop":
|
||||
# 正常结束
|
||||
pass
|
||||
elif finish_reason == "length":
|
||||
# 长度限制,尝试继续获取
|
||||
self.logger.warning(f"响应因长度限制被截断,尝试继续获取...")
|
||||
return self._continue_fetch(
|
||||
client=client,
|
||||
prompt=prompt,
|
||||
@@ -841,9 +915,25 @@ class Agent:
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
retry_count=retry_count,
|
||||
accumulated_result=result,
|
||||
)
|
||||
|
||||
result = response_data["choices"][0]["message"]["content"]
|
||||
elif finish_reason in ("tool_calls", "function_call"):
|
||||
# 工具调用场景,当前代码可能不支持,直接返回已获取结果
|
||||
self.logger.warning(f"finish_reason 为 '{finish_reason}',当前不支持工具调用,返回已获取内容")
|
||||
return result if result else (
|
||||
prompt if error_result_handler is None
|
||||
else error_result_handler(prompt, self.logger)
|
||||
)
|
||||
elif finish_reason == "content_filter":
|
||||
# 内容被过滤
|
||||
self.logger.error(f"响应内容被过滤")
|
||||
raise ValueError("内容被过滤")
|
||||
elif finish_reason is None:
|
||||
# 某些 API 可能不返回 finish_reason,将其视为正常结束
|
||||
self.logger.warning(f"API未返回 finish_reason,视为正常结束")
|
||||
else:
|
||||
# 其他未知的 finish_reason,记录警告并返回结果
|
||||
self.logger.warning(f"未知的 finish_reason: '{finish_reason}',返回已获取内容")
|
||||
|
||||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||||
extract_token_info(response_data)
|
||||
@@ -882,7 +972,7 @@ class Agent:
|
||||
self.logger.error(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
except (KeyError, IndexError, ValueError, json.JSONDecodeError) as e:
|
||||
self.logger.error(f"AI响应格式或值错误 (sync), 将尝试重试: {repr(e)}")
|
||||
should_retry = True
|
||||
is_hard_error = True
|
||||
|
||||
Reference in New Issue
Block a user