diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 7dfa115..333b42a 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -5,9 +5,12 @@
+
+
+
+
-
@@ -70,7 +73,7 @@
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager": "true",
"RunOnceActivity.git.unshallow": "true",
"git-widget-placeholder": "main",
- "last_opened_file_path": "C:/Users/jxgm/Desktop/FileTranslate/dist/app",
+ "last_opened_file_path": "C:/Users/jxgm/Desktop/FileTranslate/docutranslate/agents",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
@@ -82,11 +85,11 @@
}]]>
+
-
@@ -549,8 +552,8 @@
-
+
@@ -626,14 +629,14 @@
-
+
-
-
+
+
diff --git a/docutranslate/agents/__init__.py b/docutranslate/agents/__init__.py
index d918a04..c34968a 100644
--- a/docutranslate/agents/__init__.py
+++ b/docutranslate/agents/__init__.py
@@ -1,2 +1,2 @@
-from .agent import Agent, AgentArgs
+from .agent_async import Agent, AgentArgs
from .markdown_agent import MDRefineAgent, MDTranslateAgent
diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent_async.py
similarity index 93%
rename from docutranslate/agents/agent.py
rename to docutranslate/agents/agent_async.py
index 31f4b6d..73a9e3f 100644
--- a/docutranslate/agents/agent.py
+++ b/docutranslate/agents/agent_async.py
@@ -1,6 +1,4 @@
import asyncio
-
-# import re
from typing import TypedDict
import httpx
@@ -22,7 +20,8 @@ TIMEOUT = 500
class Agent:
- def __init__(self, baseurl:str="", key:str="xx", model_id:str="", system_prompt:str="", temperature=0.7, max_concurrent=6,timeout:int=TIMEOUT):
+ def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
+ max_concurrent=6, timeout: int = TIMEOUT):
self.baseurl = baseurl.strip()
self.key = key.strip()
self.model_id = model_id.strip()
@@ -30,7 +29,7 @@ class Agent:
self.temperature = temperature
self.client_async = httpx.AsyncClient()
self.max_concurrent = max_concurrent
- self.timeout=timeout
+ self.timeout = timeout
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
if temperature is None:
@@ -90,6 +89,7 @@ class Agent:
count = 0
semaphore = asyncio.Semaphore(max_concurrent)
tasks = []
+
# 辅助协程,用于包装 self.send_async 并使用信号量
async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放
@@ -99,7 +99,7 @@ class Agent:
)
nonlocal count
count += 1
- translater_logger.info(f"进行到{count}/{total}")
+ translater_logger.info(f"协程-已完成{count}/{total}")
return result
for p_text in prompts:
diff --git a/docutranslate/agents/agent_sync.py b/docutranslate/agents/agent_sync.py
new file mode 100644
index 0000000..72ef583
--- /dev/null
+++ b/docutranslate/agents/agent_sync.py
@@ -0,0 +1,89 @@
+from typing import TypedDict
+from docutranslate.logger import translater_logger
+import httpx
+
+
+class AgentArgs(TypedDict, total=False):
+ baseurl: str
+ key: str
+ model_id: str
+ system_prompt: str
+ temperature: float
+ max_concurrent: int
+ timeout: int
+
+
+TIMEOUT = 500
+
+
+
+class Agent:
+ def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
+ max_concurrent=6, timeout: int = TIMEOUT):
+ self.baseurl = baseurl.strip()
+ self.key = key.strip()
+ self.model_id = model_id.strip()
+ self.system_prompt = system_prompt
+ self.temperature = temperature
+ self.client = httpx.Client()
+ self.max_concurrent = max_concurrent
+ self.timeout = timeout
+
+ def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
+ if temperature is None:
+ temperature = self.temperature
+ headers = {"Content-Type": "application/json",
+ "Authorization": f"Bearer {self.key}"}
+ data = {
+ "model": self.model_id,
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ # {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt},
+ {"role": "user", "content": prompt}
+ ],
+ "temperature": temperature,
+ "top_p": top_p
+ }
+ return headers, data
+
+ def send(self, prompt: str, system_prompt: None | str = None) -> str:
+ if system_prompt is None:
+ system_prompt = self.system_prompt
+
+ """Sends a single prompt asynchronously."""
+ headers, data = self._prepare_request_data(prompt, system_prompt)
+ if self.baseurl.endswith("/"):
+ self.baseurl = self.baseurl[:-1]
+ try:
+ response = self.client.post(
+ f"{self.baseurl}/chat/completions",
+ json=data,
+ headers=headers,
+ timeout=self.timeout
+ )
+ response.raise_for_status()
+ result = response.json()["choices"][0]["message"]["content"]
+ return result
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
+ except httpx.RequestError as e:
+ raise Exception(f"AI请求连接错误 (async): {e}") from e
+ except (KeyError, IndexError) as e:
+ raise Exception(f"AI响应格式错误 (async): {e}") from e
+
+
+ def send_prompts(
+ self,
+ prompts: list[str],
+ system_prompt: str | None = None,
+ ) -> list[str]:
+ result=[]
+ for prompt in prompts:
+ result.append(self.send(prompt,system_prompt))
+ translater_logger.info(f"单线程-已完成{len(result)}/{len(prompts)}")
+ return result
+
+
+
+if __name__ == '__main__':
+ pass
diff --git a/docutranslate/agents/agent_thread.py b/docutranslate/agents/agent_thread.py
new file mode 100644
index 0000000..d8ed1df
--- /dev/null
+++ b/docutranslate/agents/agent_thread.py
@@ -0,0 +1,109 @@
+from concurrent.futures import ThreadPoolExecutor
+from threading import Lock
+from typing import TypedDict
+from docutranslate.logger import translater_logger
+import httpx
+
+
+class AgentArgs(TypedDict, total=False):
+ baseurl: str
+ key: str
+ model_id: str
+ system_prompt: str
+ temperature: float
+ max_concurrent: int
+ timeout: int
+
+
+TIMEOUT = 500
+
+
+class PromptsCount():
+ def __init__(self,max:int):
+ self.lock=Lock()
+ self.count=0
+ self.max=max
+
+ def add(self):
+ self.lock.acquire()
+ self.count+=1
+ translater_logger.info(f"多线程-已完成:{self.count}/{self.max}")
+ self.lock.release()
+
+
+class Agent:
+ def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
+ max_concurrent=6, timeout: int = TIMEOUT):
+ self.baseurl = baseurl.strip()
+ self.key = key.strip()
+ self.model_id = model_id.strip()
+ self.system_prompt = system_prompt
+ self.temperature = temperature
+ self.client = httpx.Client()
+ self.max_concurrent = max_concurrent
+ self.timeout = timeout
+
+ def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
+ if temperature is None:
+ temperature = self.temperature
+ headers = {"Content-Type": "application/json",
+ "Authorization": f"Bearer {self.key}"}
+ data = {
+ "model": self.model_id,
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ # {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt},
+ {"role": "user", "content": prompt}
+ ],
+ "temperature": temperature,
+ "top_p": top_p
+ }
+ return headers, data
+
+ def send(self, prompt: str, system_prompt: None | str = None) -> str:
+ if system_prompt is None:
+ system_prompt = self.system_prompt
+
+ """Sends a single prompt asynchronously."""
+ headers, data = self._prepare_request_data(prompt, system_prompt)
+ if self.baseurl.endswith("/"):
+ self.baseurl = self.baseurl[:-1]
+ try:
+ response = self.client.post(
+ f"{self.baseurl}/chat/completions",
+ json=data,
+ headers=headers,
+ timeout=self.timeout
+ )
+ response.raise_for_status()
+ result = response.json()["choices"][0]["message"]["content"]
+ return result
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
+ except httpx.RequestError as e:
+ raise Exception(f"AI请求连接错误 (async): {e}") from e
+ except (KeyError, IndexError) as e:
+ raise Exception(f"AI响应格式错误 (async): {e}") from e
+
+ def _send_prompt_count(self,prompt: str, system_prompt:None | str,count:PromptsCount)->str:
+ result=self.send(prompt,system_prompt)
+ count.add()
+ return result
+
+
+ def send_prompts(
+ self,
+ prompts: list[str],
+ system_prompt: str | None = None,
+ ) -> list[str]:
+ system_prompts = [system_prompt] * len(prompts)
+ counts=[PromptsCount(len(prompts))]* len(prompts)
+ output_list = []
+ with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+ results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts,counts)
+ output_list = list(results_iterator)
+ return output_list
+
+
+if __name__ == '__main__':
+ pass
diff --git a/docutranslate/agents/markdown_agent.py b/docutranslate/agents/markdown_agent.py
index 7846f4d..db27216 100644
--- a/docutranslate/agents/markdown_agent.py
+++ b/docutranslate/agents/markdown_agent.py
@@ -1,6 +1,6 @@
from typing import Unpack
-from .agent import Agent, AgentArgs
+from .agent_async import Agent, AgentArgs
class MDRefineAgent(Agent):
diff --git a/pyproject.toml b/pyproject.toml
index 92f667a..d9bc677 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "docutranslate"
-version = "0.2.2.post1"
+version = "0.2.3"
description = "文件翻译工具"
readme = "README.md"
requires-python = ">=3.10"