Files
docutranslate/docutranslate/utils/agent_utils.py
2025-05-08 19:49:31 +08:00

118 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import re
import httpx
TIMEOUT=250
class Agent:
def __init__(self, baseurl="", key="", model_id="", system_prompt="", temperature=0.7, max_concurrent=5):
self.baseurl = baseurl
self.key = key
self.model_id = model_id
self.system_prompt = system_prompt
self.temperature = temperature
# self.client=httpx.Client()
self.client_async = httpx.AsyncClient()
self.max_concurrent = max_concurrent
def _prepare_request_data(self, prompt:str, system_prompt:str, temperature=None, top_p=0.9):
if temperature is None:
temperature = self.temperature
headers = {"Content-Type": "application/json",
"Authorization": f"Bearer {self.key}"}
data = {
"model": self.model_id,
"messages": [
{"role": "system", "content": "所有回复必须严格以<<AAA>>开头,包括简短回答。示例:\n<<AAA>>这是示例回答\n"+system_prompt},
{"role": "user", "content": prompt}
],
"temperature": temperature,
"top_p": top_p
}
return headers, data
# def send_prompt(self,prompt,system_prompt=None,timeout=TIMEOUT):
# if system_prompt is None:
# system_prompt=self.system_prompt
# headers,data=self._prepare_request_data(prompt,system_prompt)
# response=self.client.post(f"{self.baseurl}/chat/completions",json=data,headers=headers,timeout=timeout)
# response.raise_for_status()
# return response.json()["choices"][0]["message"]["content"].lstrip()
async def send_async(self, prompt: str, system_prompt: None | str = None, timeout: int = TIMEOUT) -> str:
if system_prompt is None:
system_prompt = self.system_prompt
"""Sends a single prompt asynchronously."""
headers, data = self._prepare_request_data(prompt, system_prompt)
try:
response = await self.client_async.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=timeout
)
response.raise_for_status()
result=response.json()["choices"][0]["message"]["content"]
pattern=r"<<AAA>>(.*)"
match= re.search(pattern,result, re.DOTALL)
if match is None:
print("检测开头<<AAA>失败")
else:
result=match.group(1)
return result
except httpx.HTTPStatusError as e:
raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
except httpx.RequestError as e:
raise Exception(f"AI请求连接错误 (async): {e}") from e
except (KeyError, IndexError) as e:
raise Exception(f"AI响应格式错误 (async): {e}") from e
async def send_prompts_async(
self,
prompts: list[str],
system_prompt: str | None = None,
timeout: int = TIMEOUT,
max_concurrent: int = 5 # 新增参数默认并发数为5
) -> list[str]:
total = len(prompts)
count = 0
"""
Sends multiple prompts asynchronously, limiting concurrent requests.
"""
semaphore = asyncio.Semaphore(max_concurrent)
tasks = []
# 辅助协程,用于包装 self.send_async 并使用信号量
async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放
result = await self.send_async(
prompt=p_text,
system_prompt=system_prompt,
timeout=TIMEOUT
)
nonlocal count
count += 1
print(f"进行到{count}/{total}")
return result
for p_text in prompts:
task = asyncio.create_task(send_with_semaphore(p_text))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=False)
return results
def send_prompts(
self,
prompts: list[str],
system_prompt: str | None = None,
timeout: int = TIMEOUT,
) -> list[str]:
result = asyncio.run(self.send_prompts_async(prompts, system_prompt, timeout, self.max_concurrent))
return result
# Script entry guard: running this module directly does nothing (placeholder only).
if __name__ == '__main__':
    pass