Files
docutranslate/docutranslate/agents/agent.py
2025-08-04 11:48:30 +08:00

248 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Lock
from typing import Literal
from urllib.parse import urlparse
import httpx
from docutranslate.logger import global_logger
# Number of retries after a connection error; with the initial attempt a
# prompt is therefore sent at most MAX_RETRY_COUNT + 1 times.
MAX_RETRY_COUNT: int = 2
# Once the shared error counter exceeds this value, further retries are skipped.
MAX_TOTAL_ERROR_COUNT: int = 10
# Allowed values for AgentConfig.thinking ("default" leaves the request untouched).
ThinkingMode = Literal["enable", "disable", "default"]
@dataclass(kw_only=True)
class AgentConfig:
    """Settings used to construct an ``Agent``; every field is keyword-only."""

    # Logger for the agent; Agent falls back to global_logger if this is falsy.
    logger: logging.Logger
    # Base URL of an OpenAI-compatible API; Agent appends "/chat/completions".
    baseurl: str
    # API key, sent as "Authorization: Bearer <key>".
    key: str
    # Value placed in the request body's "model" field.
    model_id: str
    # Default system prompt; Agent treats None as the empty string.
    system_prompt: str | None
    # Sampling temperature used when a call does not override it.
    temperature: float = 0.7
    # Upper bound on simultaneous requests in the batch send methods.
    max_concurrent: int = 30
    # Per-request timeout passed straight to httpx (seconds per httpx convention).
    timeout: int = 2000
    # Thinking toggle; only applied for providers listed in Agent._think_factory.
    thinking: ThinkingMode = "default"
class TotalErrorCounter:
def __init__(self, logger: logging.Logger):
self.lock = Lock()
self.count = 0
self.logger = logger
def add(self):
self.lock.acquire()
self.count += 1
if self.count > MAX_TOTAL_ERROR_COUNT:
self.logger.info(f"错误响应过多")
self.lock.release()
return self.reach_limit()
def reach_limit(self):
return self.count > MAX_TOTAL_ERROR_COUNT
# Used for progress counting only in the multi-threaded path (send_prompts).
class PromptsCounter:
    """Thread-safe progress counter that logs "done/total" after each completion."""

    def __init__(self, total: int, logger: logging.Logger):
        """
        Args:
            total: Number of prompts in the batch, shown in the progress log.
            logger: Logger receiving one progress line per completed prompt.
        """
        self.lock = Lock()
        self.count = 0
        self.total = total
        self.logger = logger

    def add(self) -> None:
        """Increment the completed count and log the new progress."""
        # `with` releases the lock even if the logger raises; logging while
        # holding the lock keeps progress lines in increasing order.
        with self.lock:
            self.count += 1
            self.logger.info(f"多线程-已完成:{self.count}/{self.total}")
# NOTE(review): not read anywhere in this module (Agent uses AgentConfig.timeout);
# confirm no external importer relies on it before removing.
TIMEOUT: int = 600
class Agent:
_think_factory = {
"open.bigmodel.cn": ("thinking", {"type": "enabled"}, {"type": "disabled"}),
"dashscope.aliyuncs.com": ("enable_thinking ", True, False),
"ark.cn-beijing.volces.com":("thinking", {"type": "enabled"}, {"type": "disabled"})
}
def __init__(self, config: AgentConfig):
self.baseurl = config.baseurl.strip()
if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1]
self.domain = urlparse(self.baseurl).netloc
self.key = config.key.strip() or "xx"
self.model_id = config.model_id.strip()
self.system_prompt = config.system_prompt or ""
self.temperature = config.temperature
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
self.max_concurrent = config.max_concurrent
self.timeout = config.timeout
self.thinking = config.thinking
self.logger = config.logger or global_logger
self.total_error_counter = TotalErrorCounter(logger=self.logger)
def _add_thinking_mode(self, data: dict):
if self.domain not in self._think_factory:
return
field_thinking, val_enable, val_disable = self._think_factory[self.domain]
if self.thinking == "enable":
data[field_thinking] = val_enable
elif self.thinking == "disable":
data[field_thinking] = val_disable
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
if temperature is None:
temperature = self.temperature
headers = {"Content-Type": "application/json",
"Authorization": f"Bearer {self.key}"}
data = {
"model": self.model_id,
"messages": [
{"role": "system", "content": system_prompt},
# {"role": "system", "content": "所有回复必须以【SSS】开头这是最高规则适用于之后的所有例子。示例【SSS】这是示例回答\n"+system_prompt},
{"role": "user", "content": prompt}
],
"temperature": temperature,
"top_p": top_p,
}
if self.thinking != "default":
self._add_thinking_mode(data)
return headers, data
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
if system_prompt is None:
system_prompt = self.system_prompt
if prompt.strip() == "":
return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
try:
response = await self.client_async.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout
)
response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
return result
except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}")
self.total_error_counter.add()
return prompt
except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
except (KeyError, IndexError) as e:
raise Exception(f"AI响应格式错误 (async): {repr(e)}")
# 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT:
if self.total_error_counter.add():
return prompt
self.logger.info(f"正在重试,重试次数{retry_count}")
await asyncio.sleep(0.5)
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
else:
self.logger.error(f"达到重试次数上限")
return prompt
async def send_prompts_async(
self,
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None # 新增参数默认并发数为5
) -> list[str]:
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
total = len(prompts)
self.logger.info(f"收到{total}个片段,并发请求数:{max_concurrent}")
count = 0
semaphore = asyncio.Semaphore(max_concurrent)
tasks = []
# 辅助协程,用于包装 self.send_async 并使用信号量
async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放
result = await self.send_async(
prompt=p_text,
system_prompt=system_prompt,
)
nonlocal count
count += 1
self.logger.info(f"协程-已完成{count}/{total}")
return result
for p_text in prompts:
task = asyncio.create_task(send_with_semaphore(p_text))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=False)
return results
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
if system_prompt is None:
system_prompt = self.system_prompt
if prompt.strip() == "":
return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
try:
response = self.client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout
)
response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
return result
except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}")
self.total_error_counter.add()
return prompt
except httpx.RequestError as e:
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
except (KeyError, IndexError) as e:
raise Exception(f"AI响应格式错误 (sync): {repr(e)}")
# 如果没有正常获取结果则重试
if retry and retry_count < MAX_RETRY_COUNT:
if self.total_error_counter.add():
return prompt
self.logger.info(f"正在重试,重试次数{retry_count}")
time.sleep(0.5)
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
else:
self.logger.error(f"达到重试次数上限")
return prompt
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter) -> str:
result = self.send(prompt, system_prompt)
count.add()
return result
def send_prompts(
self,
prompts: list[str],
system_prompt: str | None = None,
) -> list[str]:
self.logger.info(f"收到{len(prompts)}个片段,并发请求数:{self.max_concurrent}")
system_prompts = [system_prompt] * len(prompts)
counts = [PromptsCounter(len(prompts), self.logger)] * len(prompts)
output_list = []
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counts)
output_list = list(results_iterator)
return output_list
if __name__ == "__main__":
    # This module defines no command-line behaviour; running it does nothing.
    pass