1436 lines
54 KiB
Python
1436 lines
54 KiB
Python
# SPDX-FileCopyrightText: 2025 QinHan
|
||
# SPDX-License-Identifier: MPL-2.0
|
||
|
||
import asyncio
|
||
import itertools
|
||
import json
|
||
import logging
|
||
import re # 新增:用于正则估算
|
||
import time
|
||
from collections import deque
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from dataclasses import dataclass
|
||
from threading import Lock
|
||
from typing import Literal, Callable, Any
|
||
from urllib.parse import urlparse
|
||
|
||
import httpx
|
||
|
||
from docutranslate.agents.provider import get_provider_by_domain
|
||
from docutranslate.agents.thinking.thinking_factory import get_thinking_mode, ProviderType
|
||
from docutranslate.logger import global_logger
|
||
from docutranslate.utils.utils import get_httpx_proxies
|
||
|
||
MAX_REQUESTS_PER_ERROR = 15
|
||
MAX_CONTINUE_FETCHES = 2 # 响应被截断时,最多继续获取的次数
|
||
|
||
ThinkingMode = Literal["enable", "disable", "default"]
|
||
|
||
|
||
class AgentResultError(ValueError):
|
||
"""一个特殊的异常,用于表示结果由AI正常返回,但返回的结果有问题。该错误不计入总错误数"""
|
||
|
||
def __init__(self, message):
|
||
super().__init__(message)
|
||
|
||
|
||
class PartialAgentResultError(ValueError):
|
||
"""一个特殊的异常,用于表示结果不完整但包含了部分成功的数据,以便触发重试。该错误不计入总错误数"""
|
||
|
||
def __init__(self, message, partial_result: dict, append_prompt: str = None):
|
||
super().__init__(message)
|
||
self.partial_result = partial_result
|
||
self.append_prompt = append_prompt
|
||
|
||
|
||
@dataclass(kw_only=True)
|
||
class AgentConfig:
|
||
logger: logging.Logger = global_logger
|
||
base_url: str
|
||
api_key: str | None = None
|
||
model_id: str
|
||
temperature: float = 0.7
|
||
concurrent: int = 30
|
||
timeout: int = 1200
|
||
thinking: ThinkingMode = "disable"
|
||
retry: int = 2
|
||
system_proxy_enable: bool = False
|
||
force_json: bool = False
|
||
rpm: int | None = None # 每分钟请求数限制
|
||
tpm: int | None = None # 每分钟Token数限制
|
||
provider: ProviderType | None = None
|
||
source_lang: str | None = None # qwen-mt: 源语言
|
||
|
||
|
||
class TotalErrorCounter:
|
||
def __init__(self, logger: logging.Logger, max_errors_count=10):
|
||
self.lock = Lock()
|
||
self.count = 0
|
||
self.logger = logger
|
||
self.max_errors_count = max_errors_count
|
||
|
||
def add(self):
|
||
with self.lock:
|
||
self.count += 1
|
||
if self.count > self.max_errors_count:
|
||
self.logger.info(f"错误响应过多")
|
||
return self.reach_limit()
|
||
|
||
def reach_limit(self):
|
||
return self.count > self.max_errors_count
|
||
|
||
|
||
class PromptsCounter:
|
||
def __init__(self, total: int, logger: logging.Logger):
|
||
self.lock = Lock()
|
||
self.count = 0
|
||
self.total = total
|
||
self.logger = logger
|
||
|
||
def add(self):
|
||
with self.lock:
|
||
self.count += 1
|
||
self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
|
||
|
||
|
||
# --- 新增 RateLimiter 类 ---
|
||
class RateLimiter:
|
||
"""
|
||
基于滑动窗口的速率限制器,支持 RPM 和 TPM 控制。
|
||
同时支持 Async 和 Sync 调用。
|
||
"""
|
||
|
||
def __init__(self, rpm: int | None, tpm: int | None):
|
||
self.rpm = rpm
|
||
self.tpm = tpm
|
||
# 双端队列存储 (timestamp, value),value对于RPM是1,对于TPM是token数量
|
||
self.request_timestamps = deque()
|
||
self.token_timestamps = deque()
|
||
self.lock = Lock() # 用于同步模式和保护共享数据
|
||
|
||
def _cleanup_window(self, now: float):
|
||
"""清理60秒窗口之前的数据"""
|
||
window_start = now - 60.0
|
||
|
||
while self.request_timestamps and self.request_timestamps[0] <= window_start:
|
||
self.request_timestamps.popleft()
|
||
|
||
while self.token_timestamps and self.token_timestamps[0][0] <= window_start:
|
||
self.token_timestamps.popleft()
|
||
|
||
def _check_and_get_wait_time(self, tokens: int) -> float:
|
||
"""检查是否满足限制,返回需要等待的秒数。如果不需等待返回 0"""
|
||
now = time.time()
|
||
self._cleanup_window(now)
|
||
|
||
wait_time = 0.0
|
||
|
||
# Check RPM
|
||
if self.rpm and len(self.request_timestamps) >= self.rpm:
|
||
earliest = self.request_timestamps[0]
|
||
wait_time = max(wait_time, 60 - (now - earliest))
|
||
|
||
# Check TPM
|
||
if self.tpm:
|
||
current_tokens = sum(t[1] for t in self.token_timestamps)
|
||
if current_tokens + tokens > self.tpm:
|
||
if self.token_timestamps:
|
||
earliest = self.token_timestamps[0][0]
|
||
wait_time = max(wait_time, 60 - (now - earliest))
|
||
else:
|
||
pass
|
||
|
||
return wait_time
|
||
|
||
def _record_usage(self, tokens: int):
|
||
"""记录使用量"""
|
||
now = time.time()
|
||
if self.rpm is not None:
|
||
self.request_timestamps.append(now)
|
||
if self.tpm is not None:
|
||
self.token_timestamps.append((now, tokens))
|
||
|
||
async def acquire_async(self, tokens: int = 0):
|
||
"""异步等待配额"""
|
||
if self.rpm is None and self.tpm is None:
|
||
return
|
||
|
||
while True:
|
||
# print(f"[RateLimiter-Async] 准备获取锁...")
|
||
with self.lock:
|
||
# print(f"[RateLimiter-Async] 已加锁 (Checking)")
|
||
|
||
wait_time = self._check_and_get_wait_time(tokens)
|
||
if wait_time <= 0:
|
||
self._record_usage(tokens)
|
||
# print(f"[RateLimiter-Async] 释放锁 (成功获取配额)")
|
||
return
|
||
|
||
# print(f"[RateLimiter-Async] 释放锁 (需等待 {wait_time:.2f}s)")
|
||
|
||
# 释放锁后等待
|
||
await asyncio.sleep(wait_time + 0.1)
|
||
|
||
def acquire_sync(self, tokens: int = 0):
|
||
"""同步等待配额(线程阻塞)"""
|
||
if self.rpm is None and self.tpm is None:
|
||
return
|
||
|
||
while True:
|
||
# print(f"[RateLimiter-Sync] 准备获取锁...")
|
||
with self.lock:
|
||
# print(f"[RateLimiter-Sync] 已加锁 (Checking)")
|
||
|
||
wait_time = self._check_and_get_wait_time(tokens)
|
||
if wait_time <= 0:
|
||
self._record_usage(tokens)
|
||
# print(f"[RateLimiter-Sync] 释放锁 (成功获取配额)")
|
||
return
|
||
|
||
# print(f"[RateLimiter-Sync] 释放锁 (需等待 {wait_time:.2f}s)")
|
||
|
||
time.sleep(wait_time + 0.1)
|
||
|
||
|
||
def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
|
||
"""(保持原样) 从API响应中提取token信息"""
|
||
if "usage" not in response_data:
|
||
return 0, 0, 0, 0
|
||
|
||
usage = response_data["usage"]
|
||
input_tokens = usage.get("prompt_tokens", 0)
|
||
output_tokens = usage.get("completion_tokens", 0)
|
||
|
||
cached_tokens = 0
|
||
reasoning_tokens = 0
|
||
try:
|
||
if (
|
||
"input_tokens_details" in usage
|
||
and "cached_tokens" in usage["input_tokens_details"]
|
||
):
|
||
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
|
||
elif (
|
||
"prompt_tokens_details" in usage
|
||
and "cached_tokens" in usage["prompt_tokens_details"]
|
||
):
|
||
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||
elif "prompt_cache_hit_tokens" in usage:
|
||
cached_tokens = usage["prompt_cache_hit_tokens"]
|
||
|
||
if (
|
||
"output_tokens_details" in usage
|
||
and "reasoning_tokens" in usage["output_tokens_details"]
|
||
):
|
||
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
|
||
elif (
|
||
"completion_tokens_details" in usage
|
||
and "reasoning_tokens" in usage["completion_tokens_details"]
|
||
):
|
||
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
|
||
return input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||
except TypeError:
|
||
return -1, -1, -1, -1
|
||
|
||
|
||
class TokenCounter:
|
||
def __init__(self, logger: logging.Logger):
|
||
self.lock = Lock()
|
||
self.input_tokens = 0
|
||
self.cached_tokens = 0
|
||
self.output_tokens = 0
|
||
self.reasoning_tokens = 0
|
||
self.total_tokens = 0
|
||
self.logger = logger
|
||
|
||
def add(
|
||
self,
|
||
input_tokens: int,
|
||
cached_tokens: int,
|
||
output_tokens: int,
|
||
reasoning_tokens: int,
|
||
):
|
||
with self.lock:
|
||
self.input_tokens += input_tokens
|
||
self.cached_tokens += cached_tokens
|
||
self.output_tokens += output_tokens
|
||
self.reasoning_tokens += reasoning_tokens
|
||
self.total_tokens += input_tokens + output_tokens
|
||
|
||
def get_stats(self):
|
||
with self.lock:
|
||
return {
|
||
"input_tokens": self.input_tokens,
|
||
"cached_tokens": self.cached_tokens,
|
||
"output_tokens": self.output_tokens,
|
||
"reasoning_tokens": self.reasoning_tokens,
|
||
"total_tokens": self.total_tokens,
|
||
}
|
||
|
||
def reset(self):
|
||
with self.lock:
|
||
self.input_tokens = 0
|
||
self.cached_tokens = 0
|
||
self.output_tokens = 0
|
||
self.reasoning_tokens = 0
|
||
self.total_tokens = 0
|
||
|
||
|
||
PreSendHandlerType = Callable[[str, str], tuple[str, str]]
|
||
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
|
||
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
||
|
||
# _CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]')
|
||
# 扩展正则范围,包含:
|
||
# CJK (中日韩): \u2e80-\u9fff
|
||
# 西里尔 (俄语等): \u0400-\u04ff
|
||
# 阿拉伯语: \u0600-\u06ff
|
||
# 泰语: \u0e00-\u0e7f
|
||
# 梵文 (印地语等): \u0900-\u097f
|
||
# 标点和特殊符号范围较广,这里主要抓取非拉丁体系的主要语言
|
||
_COMPLEX_SCRIPT_PATTERN = re.compile(
|
||
r'[\u2e80-\u9fff\u0400-\u04ff\u0600-\u06ff\u0e00-\u0e7f\u0900-\u097f]'
|
||
)
|
||
|
||
def _normalize_mt_lang_key(lang: str) -> str:
|
||
key = str(lang).strip().lower()
|
||
key = key.replace("_", "-")
|
||
key = key.replace("'", "'").replace("'", "'")
|
||
key = key.replace("–", "-").replace("—", "-")
|
||
key = re.sub(r"\s+", " ", key)
|
||
return key
|
||
|
||
|
||
_MT_LANG_BY_CODE = {
|
||
"en": "English",
|
||
"zh": "Chinese",
|
||
"zh-tw": "Traditional Chinese",
|
||
"ru": "Russian",
|
||
"ja": "Japanese",
|
||
"ko": "Korean",
|
||
"es": "Spanish",
|
||
"fr": "French",
|
||
"pt": "Portuguese",
|
||
"de": "German",
|
||
"it": "Italian",
|
||
"th": "Thai",
|
||
"vi": "Vietnamese",
|
||
"id": "Indonesian",
|
||
"ms": "Malay",
|
||
"ar": "Arabic",
|
||
"hi": "Hindi",
|
||
"he": "Hebrew",
|
||
"my": "Burmese",
|
||
"ta": "Tamil",
|
||
"ur": "Urdu",
|
||
"bn": "Bengali",
|
||
"pl": "Polish",
|
||
"nl": "Dutch",
|
||
"ro": "Romanian",
|
||
"tr": "Turkish",
|
||
"km": "Khmer",
|
||
"lo": "Lao",
|
||
"yue": "Cantonese",
|
||
"cs": "Czech",
|
||
"el": "Greek",
|
||
"sv": "Swedish",
|
||
"hu": "Hungarian",
|
||
"da": "Danish",
|
||
"fi": "Finnish",
|
||
"uk": "Ukrainian",
|
||
"bg": "Bulgarian",
|
||
"sr": "Serbian",
|
||
"te": "Telugu",
|
||
"af": "Afrikaans",
|
||
"hy": "Armenian",
|
||
"as": "Assamese",
|
||
"ast": "Asturian",
|
||
"eu": "Basque",
|
||
"be": "Belarusian",
|
||
"bs": "Bosnian",
|
||
"ca": "Catalan",
|
||
"ceb": "Cebuano",
|
||
"hr": "Croatian",
|
||
"arz": "Egyptian Arabic",
|
||
"et": "Estonian",
|
||
"gl": "Galician",
|
||
"ka": "Georgian",
|
||
"gu": "Gujarati",
|
||
"is": "Icelandic",
|
||
"jv": "Javanese",
|
||
"kn": "Kannada",
|
||
"kk": "Kazakh",
|
||
"lv": "Latvian",
|
||
"lt": "Lithuanian",
|
||
"lb": "Luxembourgish",
|
||
"mk": "Macedonian",
|
||
"mai": "Maithili",
|
||
"mt": "Maltese",
|
||
"mr": "Marathi",
|
||
"acm": "Mesopotamian Arabic",
|
||
"ary": "Moroccan Arabic",
|
||
"ars": "Najdi Arabic",
|
||
"ne": "Nepali",
|
||
"az": "North Azerbaijani",
|
||
"apc": "North Levantine Arabic",
|
||
"uz": "Northern Uzbek",
|
||
"nb": "Norwegian Bokmål",
|
||
"nn": "Norwegian Nynorsk",
|
||
"oc": "Occitan",
|
||
"or": "Odia",
|
||
"pag": "Pangasinan",
|
||
"scn": "Sicilian",
|
||
"sd": "Sindhi",
|
||
"si": "Sinhala",
|
||
"sk": "Slovak",
|
||
"sl": "Slovenian",
|
||
"ajp": "South Levantine Arabic",
|
||
"sw": "Swahili",
|
||
"tl": "Tagalog",
|
||
"acq": "Ta'izzi-Adeni Arabic",
|
||
"sq": "Tosk Albanian",
|
||
"aeb": "Tunisian Arabic",
|
||
"vec": "Venetian",
|
||
"war": "Waray",
|
||
"cy": "Welsh",
|
||
"fa": "Western Persian",
|
||
}
|
||
|
||
_MT_LANG_BY_NAME = {
|
||
_normalize_mt_lang_key(name): name for name in set(_MT_LANG_BY_CODE.values())
|
||
}
|
||
|
||
_MT_LANG_ALIASES = {
|
||
# Existing UI/common aliases
|
||
"english": "English",
|
||
"英语": "English",
|
||
"英文": "English",
|
||
"简体中文": "Chinese",
|
||
"中文": "Chinese",
|
||
"simplified chinese": "Chinese",
|
||
"chinese": "Chinese",
|
||
"traditional chinese": "Traditional Chinese",
|
||
"繁体中文": "Traditional Chinese",
|
||
"zh-hans": "Chinese",
|
||
"zh-cn": "Chinese",
|
||
"zh-hant": "Traditional Chinese",
|
||
# Full Chinese aliases from qwen-mt language list
|
||
"俄语": "Russian",
|
||
"日语": "Japanese",
|
||
"韩语": "Korean",
|
||
"西班牙语": "Spanish",
|
||
"法语": "French",
|
||
"葡萄牙语": "Portuguese",
|
||
"德语": "German",
|
||
"意大利语": "Italian",
|
||
"泰语": "Thai",
|
||
"越南语": "Vietnamese",
|
||
"印度尼西亚语": "Indonesian",
|
||
"马来语": "Malay",
|
||
"阿拉伯语": "Arabic",
|
||
"印地语": "Hindi",
|
||
"希伯来语": "Hebrew",
|
||
"缅甸语": "Burmese",
|
||
"泰米尔语": "Tamil",
|
||
"乌尔都语": "Urdu",
|
||
"孟加拉语": "Bengali",
|
||
"波兰语": "Polish",
|
||
"荷兰语": "Dutch",
|
||
"罗马尼亚语": "Romanian",
|
||
"土耳其语": "Turkish",
|
||
"高棉语": "Khmer",
|
||
"老挝语": "Lao",
|
||
"粤语": "Cantonese",
|
||
"捷克语": "Czech",
|
||
"希腊语": "Greek",
|
||
"瑞典语": "Swedish",
|
||
"匈牙利语": "Hungarian",
|
||
"丹麦语": "Danish",
|
||
"芬兰语": "Finnish",
|
||
"乌克兰语": "Ukrainian",
|
||
"保加利亚语": "Bulgarian",
|
||
"塞尔维亚语": "Serbian",
|
||
"泰卢固语": "Telugu",
|
||
"南非荷兰语": "Afrikaans",
|
||
"亚美尼亚语": "Armenian",
|
||
"阿萨姆语": "Assamese",
|
||
"阿斯图里亚斯语": "Asturian",
|
||
"巴斯克语": "Basque",
|
||
"白俄罗斯语": "Belarusian",
|
||
"波斯尼亚语": "Bosnian",
|
||
"加泰罗尼亚语": "Catalan",
|
||
"宿务语": "Cebuano",
|
||
"克罗地亚语": "Croatian",
|
||
"埃及阿拉伯语": "Egyptian Arabic",
|
||
"爱沙尼亚语": "Estonian",
|
||
"加利西亚语": "Galician",
|
||
"格鲁吉亚语": "Georgian",
|
||
"古吉拉特语": "Gujarati",
|
||
"冰岛语": "Icelandic",
|
||
"爪哇语": "Javanese",
|
||
"卡纳达语": "Kannada",
|
||
"哈萨克语": "Kazakh",
|
||
"拉脱维亚语": "Latvian",
|
||
"立陶宛语": "Lithuanian",
|
||
"卢森堡语": "Luxembourgish",
|
||
"马其顿语": "Macedonian",
|
||
"马加希语": "Maithili",
|
||
"马耳他语": "Maltese",
|
||
"马拉地语": "Marathi",
|
||
"美索不达米亚阿拉伯语": "Mesopotamian Arabic",
|
||
"摩洛哥阿拉伯语": "Moroccan Arabic",
|
||
"内志阿拉伯语": "Najdi Arabic",
|
||
"尼泊尔语": "Nepali",
|
||
"北阿塞拜疆语": "North Azerbaijani",
|
||
"北黎凡特阿拉伯语": "North Levantine Arabic",
|
||
"北乌兹别克语": "Northern Uzbek",
|
||
"书面语挪威语": "Norwegian Bokmål",
|
||
"新挪威语": "Norwegian Nynorsk",
|
||
"奥克语": "Occitan",
|
||
"奥里亚语": "Odia",
|
||
"邦阿西楠语": "Pangasinan",
|
||
"西西里语": "Sicilian",
|
||
"信德语": "Sindhi",
|
||
"僧伽罗语": "Sinhala",
|
||
"斯洛伐克语": "Slovak",
|
||
"斯洛文尼亚语": "Slovenian",
|
||
"南黎凡特阿拉伯语": "South Levantine Arabic",
|
||
"斯瓦希里语": "Swahili",
|
||
"他加禄语": "Tagalog",
|
||
"塔伊兹-亚丁阿拉伯语": "Ta'izzi-Adeni Arabic",
|
||
"托斯克阿尔巴尼亚语": "Tosk Albanian",
|
||
"突尼斯阿拉伯语": "Tunisian Arabic",
|
||
"威尼斯语": "Venetian",
|
||
"瓦莱语": "Waray",
|
||
"威尔士语": "Welsh",
|
||
"西波斯语": "Western Persian",
|
||
# English punctuation/variant aliases
|
||
"norwegian bokmal": "Norwegian Bokmål",
|
||
"ta'izzi-adeni arabic": "Ta'izzi-Adeni Arabic",
|
||
}
|
||
|
||
class Agent:
|
||
|
||
def __init__(self, config: AgentConfig):
|
||
self.baseurl = config.base_url.strip()
|
||
if self.baseurl.endswith("/"):
|
||
self.baseurl = self.baseurl[:-1]
|
||
self.domain = urlparse(self.baseurl).netloc.strip()
|
||
self.key = config.api_key.strip() if config.api_key else "xx"
|
||
self.model_id = config.model_id.strip()
|
||
self.system_prompt = ""
|
||
self.temperature = config.temperature
|
||
self.max_concurrent = config.concurrent
|
||
self.timeout = httpx.Timeout(connect=5, read=config.timeout, write=300, pool=10)
|
||
self.thinking = config.thinking
|
||
self.logger = config.logger
|
||
self.total_error_counter = TotalErrorCounter(logger=self.logger)
|
||
self.unresolved_error_lock = Lock()
|
||
self.unresolved_error_count = 0
|
||
self.token_counter = TokenCounter(logger=self.logger)
|
||
self.retry = config.retry
|
||
self.system_proxy_enable = config.system_proxy_enable
|
||
|
||
# 新增:初始化速率限制器
|
||
self.rate_limiter = RateLimiter(rpm=config.rpm, tpm=config.tpm)
|
||
|
||
self.provider = config.provider if config.provider is not None else get_provider_by_domain(self.domain)
|
||
self.is_mt_mode = "mt" in self.model_id.lower()
|
||
self.mt_source_lang = config.source_lang if config.source_lang else "auto"
|
||
self.mt_target_lang = getattr(config, "to_lang", None)
|
||
self.mt_domains = getattr(config, "custom_prompt", None)
|
||
self.mt_glossary_dict = getattr(config, "glossary_dict", None)
|
||
|
||
def _estimate_tokens(self, text: str) -> int:
|
||
"""
|
||
改进的纯 Python 估算,适配更多语言。
|
||
"""
|
||
if not text:
|
||
return 0
|
||
|
||
total_len = len(text)
|
||
|
||
# 统计复杂字符数量 (CJK, 俄语, 阿拉伯语等)
|
||
complex_char_count = len(_COMPLEX_SCRIPT_PATTERN.findall(text))
|
||
|
||
# 简单的 ASCII 或拉丁字符
|
||
simple_char_count = total_len - complex_char_count
|
||
|
||
# 权重设定:
|
||
# 复杂字符:保守估计 1.0 (GPT-4o 对中文优化很好,约为0.6-0.7,但为了限流安全,建议设高一点)
|
||
# 简单字符:0.3 (英文平均 1个token ≈ 3.5字符)
|
||
# 额外:加上消息的固定开销 (Message Overhead),通常每条消息有 3-4 个 token 的系统开销
|
||
|
||
estimated = (complex_char_count * 1.0) + (simple_char_count * 0.3)
|
||
|
||
# 向上取整
|
||
return int(estimated) + 1
|
||
|
||
def _add_thinking_mode(self, data: dict):
|
||
thinking_mode_result = get_thinking_mode(self.provider, data.get("model"))
|
||
if thinking_mode_result is None:
|
||
return
|
||
field_thinking, val_enable, val_disable = thinking_mode_result
|
||
if self.thinking == "enable":
|
||
data[field_thinking] = val_enable
|
||
elif self.thinking == "disable":
|
||
data[field_thinking] = val_disable
|
||
|
||
def _normalize_mt_lang(self, lang: str | None) -> str | None:
|
||
if lang is None:
|
||
return None
|
||
lang_text = str(lang).strip()
|
||
if not lang_text:
|
||
return None
|
||
key = _normalize_mt_lang_key(lang_text)
|
||
if key in _MT_LANG_BY_CODE:
|
||
return _MT_LANG_BY_CODE[key]
|
||
if key in _MT_LANG_BY_NAME:
|
||
return _MT_LANG_BY_NAME[key]
|
||
if key in _MT_LANG_ALIASES:
|
||
return _MT_LANG_ALIASES[key]
|
||
return lang_text
|
||
|
||
def _build_mt_translation_options(self, prompt: str = "") -> dict:
|
||
translation_options = {}
|
||
|
||
source_lang = self._normalize_mt_lang(self.mt_source_lang)
|
||
if source_lang:
|
||
translation_options["source_lang"] = source_lang
|
||
|
||
target_lang = self._normalize_mt_lang(self.mt_target_lang)
|
||
if target_lang:
|
||
translation_options["target_lang"] = target_lang
|
||
|
||
domains = str(self.mt_domains).strip() if self.mt_domains is not None else ""
|
||
if domains:
|
||
translation_options["domains"] = domains
|
||
|
||
if self.mt_glossary_dict:
|
||
terminology_list = [
|
||
{"source": src, "target": tgt}
|
||
for src, tgt in self.mt_glossary_dict.items()
|
||
if src and tgt and src.lower() in prompt.lower()
|
||
]
|
||
if terminology_list:
|
||
translation_options["terms"] = terminology_list
|
||
|
||
return translation_options
|
||
|
||
def _build_mt_user_prompt(self, prompt: str, system_prompt: str) -> str:
|
||
# MT模式下,直接返回原始prompt,不添加任何system prompt
|
||
# MT模型会把整个user prompt当作待翻译内容
|
||
return prompt
|
||
|
||
def _prepare_request_data(
|
||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9, json_format=False
|
||
):
|
||
if temperature is None:
|
||
temperature = self.temperature
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {self.key}",
|
||
}
|
||
|
||
if self.is_mt_mode:
|
||
data = {
|
||
"model": self.model_id,
|
||
"messages": [
|
||
{"role": "user", "content": self._build_mt_user_prompt(prompt, system_prompt)},
|
||
],
|
||
}
|
||
translation_options = self._build_mt_translation_options(prompt=prompt)
|
||
if translation_options:
|
||
data["translation_options"] = translation_options
|
||
return headers, data
|
||
|
||
data = {
|
||
"model": self.model_id,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": prompt},
|
||
],
|
||
"temperature": temperature,
|
||
"top_p": top_p,
|
||
}
|
||
if self.thinking != "default":
|
||
self._add_thinking_mode(data)
|
||
if json_format:
|
||
data["response_format"] = {"type": "json_object"}
|
||
return headers, data
|
||
|
||
async def _continue_fetch_async(
|
||
self,
|
||
client: httpx.AsyncClient,
|
||
prompt: str,
|
||
system_prompt: str,
|
||
force_json: bool,
|
||
pre_send_handler: PreSendHandlerType,
|
||
result_handler: ResultHandlerType,
|
||
error_result_handler: ErrorResultHandlerType,
|
||
retry_count: int,
|
||
accumulated_result: str = "",
|
||
continue_count: int = 0,
|
||
) -> Any:
|
||
"""
|
||
当 finish_reason 为 length 时,继续获取剩余内容。
|
||
注意:很多 API 并不支持这种"继续获取"模式,可能直接返回 stop 或不返回 length。
|
||
本方法具有退化机制:如果 API 不支持继续获取,会返回已累计的结果。
|
||
最多继续获取 MAX_CONTINUE_FETCHES 次,防止无限循环。
|
||
"""
|
||
if continue_count >= MAX_CONTINUE_FETCHES:
|
||
self.logger.warning(
|
||
f"已达到最大继续获取次数 ({MAX_CONTINUE_FETCHES}),返回已累计结果 ({len(accumulated_result)} 字符)")
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
|
||
self.logger.info(
|
||
f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符, 第 {continue_count + 1}/{MAX_CONTINUE_FETCHES} 次)...")
|
||
|
||
# 构造继续请求的提示
|
||
# 关键:告知模型我们已经获取了部分内容,请继续完成
|
||
continue_prompt = f"{prompt}\n\n[系统提示:请继续完成之前的响应。之前已输出内容为:\n---\n{accumulated_result}\n---\n请从中断处继续输出剩余内容。]"
|
||
|
||
if pre_send_handler:
|
||
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
|
||
|
||
# 速率限制检查
|
||
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt)
|
||
await self.rate_limiter.acquire_async(tokens=estimated_tokens)
|
||
|
||
headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json)
|
||
|
||
try:
|
||
response = await client.post(
|
||
f"{self.baseurl}/chat/completions",
|
||
json=data,
|
||
headers=headers,
|
||
timeout=self.timeout,
|
||
)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
|
||
# 安全提取 choices 和 content
|
||
choices = response_data.get("choices", [])
|
||
if not choices:
|
||
self.logger.error(f"API响应中未找到 choices 字段")
|
||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||
|
||
choice = choices[0]
|
||
finish_reason = choice.get("finish_reason", None)
|
||
message = choice.get("message", {})
|
||
additional_result = message.get("content", "")
|
||
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||
extract_token_info(response_data)
|
||
)
|
||
self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
|
||
|
||
# 累加结果
|
||
accumulated_result += additional_result
|
||
|
||
# 如果仍然是 length,继续获取(限制最大轮数防止无限循环)
|
||
if finish_reason == "length":
|
||
return await self._continue_fetch_async(
|
||
client=client,
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
retry_count=retry_count,
|
||
accumulated_result=accumulated_result,
|
||
continue_count=continue_count + 1,
|
||
)
|
||
|
||
# 非 length 结束,返回累加结果
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
|
||
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
|
||
self.logger.error(f"继续获取内容失败: {repr(e)}")
|
||
# 退化:返回已获取的部分结果,而不是报错
|
||
if accumulated_result:
|
||
self.logger.warning(f"API不支持继续获取,返回已获取的部分结果 ({len(accumulated_result)} 字符)")
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
# 如果没有部分结果,调用错误处理器
|
||
return (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
|
||
async def send_async(
|
||
self,
|
||
client: httpx.AsyncClient,
|
||
prompt: str,
|
||
system_prompt: None | str = None,
|
||
retry=True,
|
||
retry_count=0,
|
||
force_json=False,
|
||
pre_send_handler: PreSendHandlerType = None,
|
||
result_handler: ResultHandlerType = None,
|
||
error_result_handler: ErrorResultHandlerType = None,
|
||
best_partial_result: dict | None = None,
|
||
) -> Any:
|
||
if system_prompt is None:
|
||
system_prompt = self.system_prompt
|
||
if pre_send_handler:
|
||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||
|
||
# 新增:速率限制检查
|
||
# 计算估算的 tokens (system + user)
|
||
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(prompt)
|
||
# 等待配额
|
||
await self.rate_limiter.acquire_async(tokens=estimated_tokens)
|
||
|
||
headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json)
|
||
should_retry = False
|
||
is_hard_error = False
|
||
current_partial_result = None
|
||
input_tokens = 0
|
||
output_tokens = 0
|
||
|
||
try:
|
||
response = await client.post(
|
||
f"{self.baseurl}/chat/completions",
|
||
json=data,
|
||
headers=headers,
|
||
timeout=self.timeout,
|
||
)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
|
||
# 检查 finish_reason
|
||
choices = response_data.get("choices", [])
|
||
if not choices:
|
||
self.logger.error(f"API响应中未找到 choices 字段")
|
||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||
|
||
finish_reason = choices[0].get("finish_reason", None)
|
||
result = choices[0].get("message", {}).get("content", "")
|
||
|
||
# 处理不同的 finish_reason
|
||
if finish_reason == "stop":
|
||
# 正常结束
|
||
pass
|
||
elif finish_reason == "length":
|
||
# 长度限制,尝试继续获取
|
||
self.logger.warning(f"响应因长度限制被截断,尝试继续获取...")
|
||
return await self._continue_fetch_async(
|
||
client=client,
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
retry_count=retry_count,
|
||
accumulated_result=result,
|
||
)
|
||
elif finish_reason in ("tool_calls", "function_call"):
|
||
# 工具调用场景,当前代码可能不支持,直接返回已获取结果
|
||
self.logger.warning(f"finish_reason 为 '{finish_reason}',当前不支持工具调用,返回已获取内容")
|
||
return result if result else (
|
||
prompt if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
elif finish_reason == "content_filter":
|
||
# 内容被过滤
|
||
self.logger.error(f"响应内容被过滤")
|
||
raise ValueError("内容被过滤")
|
||
elif finish_reason is None:
|
||
# 某些 API 可能不返回 finish_reason,将其视为正常结束
|
||
self.logger.warning(f"API未返回 finish_reason,视为正常结束")
|
||
else:
|
||
# 其他未知的 finish_reason,记录警告并返回结果
|
||
self.logger.warning(f"未知的 finish_reason: '{finish_reason}',返回已获取内容")
|
||
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||
extract_token_info(response_data)
|
||
)
|
||
|
||
self.token_counter.add(
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||
)
|
||
|
||
if retry_count > 0:
|
||
self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。")
|
||
|
||
return (
|
||
result
|
||
if result_handler is None
|
||
else result_handler(result, prompt, self.logger)
|
||
)
|
||
|
||
except AgentResultError as e:
|
||
self.logger.error(f"AI返回结果有误: {e}")
|
||
should_retry = True
|
||
except PartialAgentResultError as e:
|
||
self.logger.error(f"收到部分返回结果,将尝试重试: {e}")
|
||
current_partial_result = e.partial_result
|
||
should_retry = True
|
||
if e.append_prompt:
|
||
prompt += e.append_prompt
|
||
|
||
except httpx.HTTPStatusError as e:
|
||
self.logger.error(
|
||
f"AI请求HTTP状态错误 (async): {e.response.status_code} - {e.response.text}"
|
||
)
|
||
should_retry = True
|
||
is_hard_error = True
|
||
# 如果是因为 Rate Limit (429) 错误,最好在这里多睡一会儿,虽然我们有了本地 Limiter
|
||
if e.response.status_code == 429:
|
||
await asyncio.sleep(5)
|
||
|
||
except httpx.RequestError as e:
|
||
self.logger.error(f"AI请求连接错误 (async): {repr(e)}")
|
||
should_retry = True
|
||
is_hard_error = True
|
||
except (KeyError, IndexError, ValueError, json.JSONDecodeError) as e:
|
||
self.logger.error(f"AI响应格式或值错误 (async), 将尝试重试: {repr(e)}")
|
||
should_retry = True
|
||
is_hard_error = True
|
||
|
||
if current_partial_result:
|
||
best_partial_result = current_partial_result
|
||
|
||
if should_retry and retry and retry_count < self.retry:
|
||
if is_hard_error:
|
||
if retry_count == 0:
|
||
if self.total_error_counter.add():
|
||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
return (
|
||
best_partial_result
|
||
if best_partial_result
|
||
else (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
)
|
||
elif self.total_error_counter.reach_limit():
|
||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
return (
|
||
best_partial_result
|
||
if best_partial_result
|
||
else (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
)
|
||
|
||
self.logger.info(f"正在重试第 {retry_count + 1}/{self.retry} 次...")
|
||
# 指数退避
|
||
await asyncio.sleep(0.5 * (2 ** retry_count))
|
||
return await self.send_async(
|
||
client,
|
||
prompt,
|
||
system_prompt,
|
||
retry=True,
|
||
retry_count=retry_count + 1,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
best_partial_result=best_partial_result,
|
||
)
|
||
else:
|
||
if should_retry:
|
||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
|
||
if best_partial_result:
|
||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||
return best_partial_result
|
||
|
||
return (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
|
||
async def send_prompts_async(
|
||
self,
|
||
prompts: list[str],
|
||
system_prompt: str | None = None,
|
||
max_concurrent: int | None = None,
|
||
force_json=False,
|
||
pre_send_handler: PreSendHandlerType = None,
|
||
result_handler: ResultHandlerType = None,
|
||
error_result_handler: ErrorResultHandlerType = None,
|
||
) -> list[Any]:
|
||
max_concurrent = (
|
||
self.max_concurrent if max_concurrent is None else max_concurrent
|
||
)
|
||
total = len(prompts)
|
||
rpm_info = f", RPM:{self.rate_limiter.rpm}" if self.rate_limiter.rpm else ""
|
||
tpm_info = f", TPM:{self.rate_limiter.tpm}" if self.rate_limiter.tpm else ""
|
||
|
||
self.logger.info(
|
||
f"provider:{self.provider},base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent}{rpm_info}{tpm_info},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{force_json}"
|
||
)
|
||
self.logger.info(f"预计发送{total}个请求")
|
||
|
||
self.total_error_counter.max_errors_count = (
|
||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||
)
|
||
|
||
self.unresolved_error_count = 0
|
||
self.token_counter.reset()
|
||
|
||
count = 0
|
||
semaphore = asyncio.Semaphore(max_concurrent)
|
||
tasks = []
|
||
|
||
proxies = get_httpx_proxies(asyn=True) if self.system_proxy_enable else None
|
||
|
||
limits = httpx.Limits(
|
||
max_connections=self.max_concurrent * 2,
|
||
max_keepalive_connections=self.max_concurrent,
|
||
)
|
||
|
||
async with httpx.AsyncClient(
|
||
trust_env=False, mounts=proxies, verify=False, limits=limits
|
||
) as client:
|
||
async def send_with_semaphore(p_text: str):
|
||
async with semaphore:
|
||
# 注意:我们在 semaphore 内部调用 send_async
|
||
# send_async 内部会调用 rate_limiter.acquire_async
|
||
# 这样可以防止并发过高,同时 rate_limiter 防止频率过快
|
||
result = await self.send_async(
|
||
client=client,
|
||
prompt=p_text,
|
||
system_prompt=system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
)
|
||
nonlocal count
|
||
count += 1
|
||
self.logger.info(f"协程-已完成{count}/{total}")
|
||
return result
|
||
|
||
for p_text in prompts:
|
||
task = asyncio.create_task(send_with_semaphore(p_text))
|
||
tasks.append(task)
|
||
|
||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||
|
||
self.logger.info(
|
||
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||
)
|
||
|
||
token_stats = self.token_counter.get_stats()
|
||
self.logger.info(
|
||
f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
|
||
f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
|
||
f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
|
||
)
|
||
|
||
return results
|
||
|
||
def _continue_fetch(
|
||
self,
|
||
client: httpx.Client,
|
||
prompt: str,
|
||
system_prompt: str,
|
||
force_json: bool,
|
||
pre_send_handler,
|
||
result_handler,
|
||
error_result_handler,
|
||
retry_count: int,
|
||
accumulated_result: str = "",
|
||
continue_count: int = 0,
|
||
) -> Any:
|
||
"""
|
||
当 finish_reason 为 length 时,继续获取剩余内容(同步版本)。
|
||
注意:很多 API 并不支持这种"继续获取"模式,可能直接返回 stop 或不返回 length。
|
||
本方法具有退化机制:如果 API 不支持继续获取,会返回已累计的结果。
|
||
最多继续获取 MAX_CONTINUE_FETCHES 次,防止无限循环。
|
||
"""
|
||
if continue_count >= MAX_CONTINUE_FETCHES:
|
||
self.logger.warning(
|
||
f"已达到最大继续获取次数 ({MAX_CONTINUE_FETCHES}),返回已累计结果 ({len(accumulated_result)} 字符)")
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
|
||
self.logger.info(
|
||
f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符, 第 {continue_count + 1}/{MAX_CONTINUE_FETCHES} 次)...")
|
||
|
||
# 构造继续请求的提示
|
||
continue_prompt = f"{prompt}\n\n[系统提示:请继续完成之前的响应。之前已输出内容为:\n---\n{accumulated_result}\n---\n请从中断处继续输出剩余内容。]"
|
||
|
||
if pre_send_handler:
|
||
system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt)
|
||
|
||
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt)
|
||
self.rate_limiter.acquire_sync(tokens=estimated_tokens)
|
||
|
||
headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json)
|
||
|
||
try:
|
||
response = client.post(
|
||
f"{self.baseurl}/chat/completions",
|
||
json=data,
|
||
headers=headers,
|
||
timeout=self.timeout,
|
||
)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
|
||
# 安全提取 choices 和 content
|
||
choices = response_data.get("choices", [])
|
||
if not choices:
|
||
self.logger.error(f"API响应中未找到 choices 字段")
|
||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||
|
||
choice = choices[0]
|
||
finish_reason = choice.get("finish_reason", None)
|
||
message = choice.get("message", {})
|
||
additional_result = message.get("content", "")
|
||
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||
extract_token_info(response_data)
|
||
)
|
||
self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens)
|
||
|
||
accumulated_result += additional_result
|
||
|
||
if finish_reason == "length":
|
||
return self._continue_fetch(
|
||
client=client,
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
retry_count=retry_count,
|
||
accumulated_result=accumulated_result,
|
||
continue_count=continue_count + 1,
|
||
)
|
||
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
|
||
except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e:
|
||
self.logger.error(f"继续获取内容失败: {repr(e)}")
|
||
# 退化:返回已获取的部分结果,而不是报错
|
||
if accumulated_result:
|
||
self.logger.warning(f"API不支持继续获取,返回已获取的部分结果 ({len(accumulated_result)} 字符)")
|
||
return (
|
||
accumulated_result
|
||
if result_handler is None
|
||
else result_handler(accumulated_result, prompt, self.logger)
|
||
)
|
||
return (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
|
||
def send(
|
||
self,
|
||
client: httpx.Client,
|
||
prompt: str,
|
||
system_prompt: None | str = None,
|
||
retry=True,
|
||
retry_count=0,
|
||
force_json=False,
|
||
pre_send_handler=None,
|
||
result_handler=None,
|
||
error_result_handler=None,
|
||
best_partial_result: dict | None = None,
|
||
) -> Any:
|
||
if system_prompt is None:
|
||
system_prompt = self.system_prompt
|
||
if pre_send_handler:
|
||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||
|
||
# 新增:同步环境下的速率限制
|
||
estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(prompt)
|
||
self.rate_limiter.acquire_sync(tokens=estimated_tokens)
|
||
|
||
headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json)
|
||
should_retry = False
|
||
is_hard_error = False
|
||
current_partial_result = None
|
||
|
||
try:
|
||
response = client.post(
|
||
f"{self.baseurl}/chat/completions",
|
||
json=data,
|
||
headers=headers,
|
||
timeout=self.timeout,
|
||
)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
|
||
# 检查 finish_reason
|
||
choices = response_data.get("choices", [])
|
||
if not choices:
|
||
self.logger.error(f"API响应中未找到 choices 字段")
|
||
raise ValueError("API响应格式错误:缺少 choices 字段")
|
||
|
||
finish_reason = choices[0].get("finish_reason", None)
|
||
result = choices[0].get("message", {}).get("content", "")
|
||
|
||
# 处理不同的 finish_reason
|
||
if finish_reason == "stop":
|
||
# 正常结束
|
||
pass
|
||
elif finish_reason == "length":
|
||
# 长度限制,尝试继续获取
|
||
self.logger.warning(f"响应因长度限制被截断,尝试继续获取...")
|
||
return self._continue_fetch(
|
||
client=client,
|
||
prompt=prompt,
|
||
system_prompt=system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
retry_count=retry_count,
|
||
accumulated_result=result,
|
||
)
|
||
elif finish_reason in ("tool_calls", "function_call"):
|
||
# 工具调用场景,当前代码可能不支持,直接返回已获取结果
|
||
self.logger.warning(f"finish_reason 为 '{finish_reason}',当前不支持工具调用,返回已获取内容")
|
||
return result if result else (
|
||
prompt if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
elif finish_reason == "content_filter":
|
||
# 内容被过滤
|
||
self.logger.error(f"响应内容被过滤")
|
||
raise ValueError("内容被过滤")
|
||
elif finish_reason is None:
|
||
# 某些 API 可能不返回 finish_reason,将其视为正常结束
|
||
self.logger.warning(f"API未返回 finish_reason,视为正常结束")
|
||
else:
|
||
# 其他未知的 finish_reason,记录警告并返回结果
|
||
self.logger.warning(f"未知的 finish_reason: '{finish_reason}',返回已获取内容")
|
||
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (
|
||
extract_token_info(response_data)
|
||
)
|
||
|
||
self.token_counter.add(
|
||
input_tokens, cached_tokens, output_tokens, reasoning_tokens
|
||
)
|
||
|
||
if retry_count > 0:
|
||
self.logger.info(f"重试成功 (第 {retry_count}/{self.retry} 次尝试)。")
|
||
|
||
return (
|
||
result
|
||
if result_handler is None
|
||
else result_handler(result, prompt, self.logger)
|
||
)
|
||
except AgentResultError as e:
|
||
self.logger.error(f"AI返回结果有误: {e}")
|
||
should_retry = True
|
||
except PartialAgentResultError as e:
|
||
self.logger.error(f"收到部分翻译结果,将尝试重试: {e}")
|
||
current_partial_result = e.partial_result
|
||
should_retry = True
|
||
|
||
except httpx.HTTPStatusError as e:
|
||
self.logger.error(
|
||
f"AI请求HTTP状态错误 (sync): {e.response.status_code} - {e.response.text}"
|
||
)
|
||
should_retry = True
|
||
is_hard_error = True
|
||
if e.response.status_code == 429:
|
||
time.sleep(5)
|
||
|
||
except httpx.RequestError as e:
|
||
self.logger.error(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||
should_retry = True
|
||
is_hard_error = True
|
||
except (KeyError, IndexError, ValueError, json.JSONDecodeError) as e:
|
||
self.logger.error(f"AI响应格式或值错误 (sync), 将尝试重试: {repr(e)}")
|
||
should_retry = True
|
||
is_hard_error = True
|
||
|
||
if current_partial_result:
|
||
best_partial_result = current_partial_result
|
||
|
||
if should_retry and retry and retry_count < self.retry:
|
||
if is_hard_error:
|
||
if retry_count == 0:
|
||
if self.total_error_counter.add():
|
||
self.logger.error("错误次数过多,已达到上限,不再重试。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
return (
|
||
best_partial_result
|
||
if best_partial_result
|
||
else (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
)
|
||
elif self.total_error_counter.reach_limit():
|
||
self.logger.error("错误次数过多,已达到上限,不再为该请求重试。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
return (
|
||
best_partial_result
|
||
if best_partial_result
|
||
else (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
)
|
||
|
||
self.logger.info(f"正在重试第 {retry_count + 1}/{self.retry} 次...")
|
||
time.sleep(0.5 * (2 ** retry_count))
|
||
return self.send(
|
||
client,
|
||
prompt,
|
||
system_prompt,
|
||
retry=True,
|
||
retry_count=retry_count + 1,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
best_partial_result=best_partial_result,
|
||
)
|
||
else:
|
||
if should_retry:
|
||
self.logger.error(f"所有重试均失败,已达到重试次数上限。")
|
||
with self.unresolved_error_lock:
|
||
self.unresolved_error_count += 1
|
||
|
||
if best_partial_result:
|
||
self.logger.info("所有重试失败,但存在部分翻译结果,将使用该结果。")
|
||
return best_partial_result
|
||
|
||
return (
|
||
prompt
|
||
if error_result_handler is None
|
||
else error_result_handler(prompt, self.logger)
|
||
)
|
||
|
||
def _send_prompt_count(
|
||
self,
|
||
client: httpx.Client,
|
||
prompt: str,
|
||
system_prompt: None | str,
|
||
force_json,
|
||
count: PromptsCounter,
|
||
pre_send_handler,
|
||
result_handler,
|
||
error_result_handler
|
||
) -> Any:
|
||
# 该方法在 ThreadPoolExecutor 中运行
|
||
result = self.send(
|
||
client,
|
||
prompt,
|
||
system_prompt,
|
||
force_json=force_json,
|
||
pre_send_handler=pre_send_handler,
|
||
result_handler=result_handler,
|
||
error_result_handler=error_result_handler,
|
||
)
|
||
count.add()
|
||
return result
|
||
|
||
def send_prompts(
|
||
self,
|
||
prompts: list[str],
|
||
system_prompt: str | None = None,
|
||
json_format=False,
|
||
pre_send_handler: PreSendHandlerType = None,
|
||
result_handler: ResultHandlerType = None,
|
||
error_result_handler: ErrorResultHandlerType = None,
|
||
) -> list[Any]:
|
||
rpm_info = f", RPM:{self.rate_limiter.rpm}" if self.rate_limiter.rpm else ""
|
||
tpm_info = f", TPM:{self.rate_limiter.tpm}" if self.rate_limiter.tpm else ""
|
||
|
||
self.logger.info(
|
||
f"provider:{self.provider},base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent}{rpm_info}{tpm_info},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{json_format}"
|
||
)
|
||
self.logger.info(
|
||
f"预计发送{len(prompts)}个请求"
|
||
)
|
||
self.total_error_counter.max_errors_count = (
|
||
len(prompts) // MAX_REQUESTS_PER_ERROR
|
||
)
|
||
|
||
self.unresolved_error_count = 0
|
||
self.token_counter.reset()
|
||
|
||
counter = PromptsCounter(len(prompts), self.logger)
|
||
|
||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||
json_formats = itertools.repeat(json_format, len(prompts))
|
||
counters = itertools.repeat(counter, len(prompts))
|
||
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
|
||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||
limits = httpx.Limits(
|
||
max_connections=self.max_concurrent * 2,
|
||
max_keepalive_connections=self.max_concurrent,
|
||
)
|
||
proxies = get_httpx_proxies(asyn=False) if self.system_proxy_enable else None
|
||
|
||
with httpx.Client(
|
||
trust_env=False, mounts=proxies, verify=False, limits=limits
|
||
) as client:
|
||
clients = itertools.repeat(client, len(prompts))
|
||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||
results_iterator = executor.map(
|
||
self._send_prompt_count,
|
||
clients,
|
||
prompts,
|
||
system_prompts,
|
||
json_formats,
|
||
counters,
|
||
pre_send_handlers,
|
||
result_handlers,
|
||
error_result_handlers,
|
||
)
|
||
output_list = list(results_iterator)
|
||
|
||
self.logger.info(
|
||
f"所有请求处理完毕。未解决的错误总数: {self.unresolved_error_count}"
|
||
)
|
||
|
||
token_stats = self.token_counter.get_stats()
|
||
self.logger.info(
|
||
f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
|
||
f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
|
||
f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
|
||
)
|
||
|
||
return output_list
|