From 9ba8f142ce5ab6c1aaf87196b6a17b24ec3f539c Mon Sep 17 00:00:00 2001 From: xunbu Date: Tue, 13 May 2025 16:57:51 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BE=9B=E4=BA=86=E5=A4=9A=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=E5=92=8C=E5=90=8C=E6=AD=A5=E4=B8=A4=E7=A7=8D=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E6=96=B9=E6=A1=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .idea/workspace.xml | 17 +-- docutranslate/agents/__init__.py | 2 +- .../agents/{agent.py => agent_async.py} | 10 +- docutranslate/agents/agent_sync.py | 89 ++++++++++++++ docutranslate/agents/agent_thread.py | 109 ++++++++++++++++++ docutranslate/agents/markdown_agent.py | 2 +- pyproject.toml | 2 +- 7 files changed, 216 insertions(+), 15 deletions(-) rename docutranslate/agents/{agent.py => agent_async.py} (93%) create mode 100644 docutranslate/agents/agent_sync.py create mode 100644 docutranslate/agents/agent_thread.py diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 7dfa115..333b42a 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,9 +5,12 @@ + + + + - + - @@ -549,8 +552,8 @@ - + @@ -626,14 +629,14 @@ - + - - + + diff --git a/docutranslate/agents/__init__.py b/docutranslate/agents/__init__.py index d918a04..c34968a 100644 --- a/docutranslate/agents/__init__.py +++ b/docutranslate/agents/__init__.py @@ -1,2 +1,2 @@ -from .agent import Agent, AgentArgs +from .agent_async import Agent, AgentArgs from .markdown_agent import MDRefineAgent, MDTranslateAgent diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent_async.py similarity index 93% rename from docutranslate/agents/agent.py rename to docutranslate/agents/agent_async.py index 31f4b6d..73a9e3f 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent_async.py @@ -1,6 +1,4 @@ import asyncio - -# import re from typing import TypedDict import httpx @@ -22,7 +20,8 @@ TIMEOUT = 500 class Agent: - def __init__(self, baseurl:str="", key:str="xx", model_id:str="", system_prompt:str="", temperature=0.7, max_concurrent=6,timeout:int=TIMEOUT): + def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7, + max_concurrent=6, timeout: int = TIMEOUT): self.baseurl = baseurl.strip() self.key = key.strip() self.model_id = model_id.strip() @@ -30,7 +29,7 @@ class Agent: self.temperature = temperature self.client_async = httpx.AsyncClient() self.max_concurrent = max_concurrent - self.timeout=timeout + self.timeout = timeout def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): if temperature is None: @@ -90,6 +89,7 @@ class Agent: count = 0 semaphore = asyncio.Semaphore(max_concurrent) tasks = [] + # 辅助协程,用于包装 self.send_async 并使用信号量 async def send_with_semaphore(p_text: str): async with semaphore: # 在进入代码块前获取信号量,退出时释放 @@ -99,7 +99,7 @@ class Agent: ) nonlocal count count += 1 - translater_logger.info(f"进行到{count}/{total}") + translater_logger.info(f"协程-已完成{count}/{total}") return result for p_text in prompts: diff --git a/docutranslate/agents/agent_sync.py b/docutranslate/agents/agent_sync.py new file mode 100644 index 0000000..72ef583 --- /dev/null +++ b/docutranslate/agents/agent_sync.py @@ -0,0 +1,89 @@ +from typing import TypedDict +from docutranslate.logger import translater_logger +import httpx + + +class AgentArgs(TypedDict, total=False): + baseurl: str + key: str + model_id: str + system_prompt: str + temperature: float + max_concurrent: int + timeout: int + + +TIMEOUT = 500 + + + +class Agent: + def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7, + max_concurrent=6, timeout: int = TIMEOUT): + self.baseurl = baseurl.strip() + self.key = key.strip() + self.model_id = model_id.strip() + self.system_prompt = system_prompt + self.temperature = temperature + self.client = httpx.Client() + self.max_concurrent = max_concurrent + self.timeout = timeout + + def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): + if temperature is None: + temperature = self.temperature + headers = {"Content-Type": "application/json", + "Authorization": f"Bearer {self.key}"} + data = { + "model": self.model_id, + "messages": [ + {"role": "system", "content": system_prompt}, + # {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt}, + {"role": "user", "content": prompt} + ], + "temperature": temperature, + "top_p": top_p + } + return headers, data + + def send(self, prompt: str, system_prompt: None | str = None) -> str: + if system_prompt is None: + system_prompt = self.system_prompt + + """Sends a single prompt asynchronously.""" + headers, data = self._prepare_request_data(prompt, system_prompt) + if self.baseurl.endswith("/"): + self.baseurl = self.baseurl[:-1] + try: + response = self.client.post( + f"{self.baseurl}/chat/completions", + json=data, + headers=headers, + timeout=self.timeout + ) + response.raise_for_status() + result = response.json()["choices"][0]["message"]["content"] + return result + except httpx.HTTPStatusError as e: + raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e + except httpx.RequestError as e: + raise Exception(f"AI请求连接错误 (async): {e}") from e + except (KeyError, IndexError) as e: + raise Exception(f"AI响应格式错误 (async): {e}") from e + + + def send_prompts( + self, + prompts: list[str], + system_prompt: str | None = None, + ) -> list[str]: + result=[] + for prompt in prompts: + result.append(self.send(prompt,system_prompt)) + translater_logger.info(f"单线程-已完成{len(result)}/{len(prompts)}") + return result + + + +if __name__ == '__main__': + pass diff --git a/docutranslate/agents/agent_thread.py b/docutranslate/agents/agent_thread.py new file mode 100644 index 0000000..d8ed1df --- /dev/null +++ b/docutranslate/agents/agent_thread.py @@ -0,0 +1,109 @@ +from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from typing import TypedDict +from docutranslate.logger import translater_logger +import httpx + + +class AgentArgs(TypedDict, total=False): + baseurl: str + key: str + model_id: str + system_prompt: str + temperature: float + max_concurrent: int + timeout: int + + +TIMEOUT = 500 + + +class PromptsCount(): + def __init__(self,max:int): + self.lock=Lock() + self.count=0 + self.max=max + + def add(self): + self.lock.acquire() + self.count+=1 + translater_logger.info(f"多线程-已完成:{self.count}/{self.max}") + self.lock.release() + + +class Agent: + def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7, + max_concurrent=6, timeout: int = TIMEOUT): + self.baseurl = baseurl.strip() + self.key = key.strip() + self.model_id = model_id.strip() + self.system_prompt = system_prompt + self.temperature = temperature + self.client = httpx.Client() + self.max_concurrent = max_concurrent + self.timeout = timeout + + def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9): + if temperature is None: + temperature = self.temperature + headers = {"Content-Type": "application/json", + "Authorization": f"Bearer {self.key}"} + data = { + "model": self.model_id, + "messages": [ + {"role": "system", "content": system_prompt}, + # {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt}, + {"role": "user", "content": prompt} + ], + "temperature": temperature, + "top_p": top_p + } + return headers, data + + def send(self, prompt: str, system_prompt: None | str = None) -> str: + if system_prompt is None: + system_prompt = self.system_prompt + + """Sends a single prompt asynchronously.""" + headers, data = self._prepare_request_data(prompt, system_prompt) + if self.baseurl.endswith("/"): + self.baseurl = self.baseurl[:-1] + try: + response = self.client.post( + f"{self.baseurl}/chat/completions", + json=data, + headers=headers, + timeout=self.timeout + ) + response.raise_for_status() + result = response.json()["choices"][0]["message"]["content"] + return result + except httpx.HTTPStatusError as e: + raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e + except httpx.RequestError as e: + raise Exception(f"AI请求连接错误 (async): {e}") from e + except (KeyError, IndexError) as e: + raise Exception(f"AI响应格式错误 (async): {e}") from e + + def _send_prompt_count(self,prompt: str, system_prompt:None | str,count:PromptsCount)->str: + result=self.send(prompt,system_prompt) + count.add() + return result + + + def send_prompts( + self, + prompts: list[str], + system_prompt: str | None = None, + ) -> list[str]: + system_prompts = [system_prompt] * len(prompts) + counts=[PromptsCount(len(prompts))]* len(prompts) + output_list = [] + with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: + results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts,counts) + output_list = list(results_iterator) + return output_list + + +if __name__ == '__main__': + pass diff --git a/docutranslate/agents/markdown_agent.py b/docutranslate/agents/markdown_agent.py index 7846f4d..db27216 100644 --- a/docutranslate/agents/markdown_agent.py +++ b/docutranslate/agents/markdown_agent.py @@ -1,6 +1,6 @@ from typing import Unpack -from .agent import Agent, AgentArgs +from .agent_async import Agent, AgentArgs class MDRefineAgent(Agent): diff --git a/pyproject.toml b/pyproject.toml index 92f667a..d9bc677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "docutranslate" -version = "0.2.2.post1" +version = "0.2.3" description = "文件翻译工具" readme = "README.md" requires-python = ">=3.10"