Files
docutranslate/docutranslate/agents/agent_sync.py

90 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import TypedDict
from docutranslate.logger import translater_logger
import httpx
class AgentArgs(TypedDict, total=False):
    """Keyword-argument bundle for configuring an agent.

    Declared with ``total=False``, so every key is optional. The field names
    and types mirror the keyword parameters of ``Agent.__init__``.
    """

    # Endpoint prefix of the chat service.
    baseurl: str
    # API key for the service.
    key: str
    # Identifier of the model to request.
    model_id: str
    # Default system message.
    system_prompt: str
    # Sampling temperature.
    temperature: float
    # Concurrency limit.
    max_concurrent: int
    # Request timeout.
    timeout: int
TIMEOUT = 500
class Agent:
def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
max_concurrent=6, timeout: int = TIMEOUT):
self.baseurl = baseurl.strip()
self.key = key.strip()
self.model_id = model_id.strip()
self.system_prompt = system_prompt
self.temperature = temperature
self.client = httpx.Client()
self.max_concurrent = max_concurrent
self.timeout = timeout
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
if temperature is None:
temperature = self.temperature
headers = {"Content-Type": "application/json",
"Authorization": f"Bearer {self.key}"}
data = {
"model": self.model_id,
"messages": [
{"role": "system", "content": system_prompt},
# {"role": "system", "content": "所有回复必须以【SSS】开头这是最高规则适用于之后的所有例子。示例【SSS】这是示例回答\n"+system_prompt},
{"role": "user", "content": prompt}
],
"temperature": temperature,
"top_p": top_p
}
return headers, data
def send(self, prompt: str, system_prompt: None | str = None) -> str:
if system_prompt is None:
system_prompt = self.system_prompt
"""Sends a single prompt asynchronously."""
headers, data = self._prepare_request_data(prompt, system_prompt)
if self.baseurl.endswith("/"):
self.baseurl = self.baseurl[:-1]
try:
response = self.client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
timeout=self.timeout
)
response.raise_for_status()
result = response.json()["choices"][0]["message"]["content"]
return result
except httpx.HTTPStatusError as e:
raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
except httpx.RequestError as e:
raise Exception(f"AI请求连接错误 (async): {e}") from e
except (KeyError, IndexError) as e:
raise Exception(f"AI响应格式错误 (async): {e}") from e
def send_prompts(
self,
prompts: list[str],
system_prompt: str | None = None,
) -> list[str]:
result=[]
for prompt in prompts:
result.append(self.send(prompt,system_prompt))
translater_logger.info(f"单线程-已完成{len(result)}/{len(prompts)}")
return result
# Placeholder entry point: executing this module directly does nothing.
if __name__ == '__main__':
    pass