提供了多线程和同步两种实现方案
This commit is contained in:
17
.idea/workspace.xml
generated
17
.idea/workspace.xml
generated
@@ -5,9 +5,12 @@
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="6b18b44a-df57-4212-a857-9e291ebe5dd2" name="更改" comment="">
|
||||
<change afterPath="$PROJECT_DIR$/docutranslate/agents/agent_sync.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/docutranslate/agents/agent_thread.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/docutranslate/agents/__init__.py" beforeDir="false" afterPath="$PROJECT_DIR$/docutranslate/agents/__init__.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/docutranslate/agents/agent.py" beforeDir="false" afterPath="$PROJECT_DIR$/docutranslate/agents/agent_async.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/docutranslate/agents/markdown_agent.py" beforeDir="false" afterPath="$PROJECT_DIR$/docutranslate/agents/markdown_agent.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/docutranslate/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/docutranslate/app.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/pyproject.toml" beforeDir="false" afterPath="$PROJECT_DIR$/pyproject.toml" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
@@ -70,7 +73,7 @@
|
||||
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager": "true",
|
||||
"RunOnceActivity.git.unshallow": "true",
|
||||
"git-widget-placeholder": "main",
|
||||
"last_opened_file_path": "C:/Users/jxgm/Desktop/FileTranslate/dist/app",
|
||||
"last_opened_file_path": "C:/Users/jxgm/Desktop/FileTranslate/docutranslate/agents",
|
||||
"node.js.detected.package.eslint": "true",
|
||||
"node.js.detected.package.tslint": "true",
|
||||
"node.js.selected.package.eslint": "(autodetect)",
|
||||
@@ -82,11 +85,11 @@
|
||||
}]]></component>
|
||||
<component name="RecentsManager">
|
||||
<key name="CopyFile.RECENT_KEYS">
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\docutranslate\agents" />
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\dist\app" />
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\tests\files" />
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\tests" />
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\docutranslate" />
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\tests\备份" />
|
||||
</key>
|
||||
<key name="MoveFile.RECENT_KEYS">
|
||||
<recent name="C:\Users\jxgm\Desktop\FileTranslate\dist\app" />
|
||||
@@ -549,8 +552,8 @@
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.app" />
|
||||
<item itemvalue="Python.app2" />
|
||||
<item itemvalue="Python.test" />
|
||||
<item itemvalue="Python.app2" />
|
||||
<item itemvalue="JavaScript 调试.regex.md_中文.html" />
|
||||
<item itemvalue="Python.切分测试" />
|
||||
</list>
|
||||
@@ -626,14 +629,14 @@
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/filetranslate$agent_utils.coverage" NAME="agent_utils 覆盖结果" MODIFIED="1746708534311" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate/utils" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$.coverage" NAME="切分测试 覆盖结果" MODIFIED="1747057615595" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$test.coverage" NAME="test 覆盖结果" MODIFIED="1747106671303" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$test.coverage" NAME="test 覆盖结果" MODIFIED="1747125597841" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$convert.coverage" NAME="convert 覆盖结果" MODIFIED="1746963490689" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate/utils" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$test1.coverage" NAME="test1 覆盖结果" MODIFIED="1746936018440" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
<SUITE FILE_PATH="coverage/PDFtranslate$PDFtranslater__1_.coverage" NAME="PDFtranslater (1) 覆盖结果" MODIFIED="1746633258205" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/pdftranslate_packages" />
|
||||
<SUITE FILE_PATH="coverage/PDFtranslate$convert.coverage" NAME="convert 覆盖结果" MODIFIED="1746596984213" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/pdftranslate_packages/utils" />
|
||||
<SUITE FILE_PATH="coverage/PDFtranslate$agent_utils.coverage" NAME="agent_utils 覆盖结果" MODIFIED="1746617703678" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/pdftranslate_packages/utils" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$app2.coverage" NAME="app2 覆盖结果" MODIFIED="1747105074010" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$app.coverage" NAME="app 覆盖结果" MODIFIED="1747107758688" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$app2.coverage" NAME="app2 覆盖结果" MODIFIED="1747108180309" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$app.coverage" NAME="app 覆盖结果" MODIFIED="1747126209674" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/docutranslate" />
|
||||
<SUITE FILE_PATH="coverage/PDFtranslate$markdown_splitter.coverage" NAME="markdown_splitter 覆盖结果" MODIFIED="1746599883603" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/pdftranslate_packages/utils" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$test4.coverage" NAME="test4 覆盖结果" MODIFIED="1746887036353" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
<SUITE FILE_PATH="coverage/filetranslate$test3.coverage" NAME="test3 覆盖结果" MODIFIED="1746884110572" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tests" />
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .agent import Agent, AgentArgs
|
||||
from .agent_async import Agent, AgentArgs
|
||||
from .markdown_agent import MDRefineAgent, MDTranslateAgent
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
|
||||
# import re
|
||||
from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
@@ -22,7 +20,8 @@ TIMEOUT = 500
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, baseurl:str="", key:str="xx", model_id:str="", system_prompt:str="", temperature=0.7, max_concurrent=6,timeout:int=TIMEOUT):
|
||||
def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
|
||||
max_concurrent=6, timeout: int = TIMEOUT):
|
||||
self.baseurl = baseurl.strip()
|
||||
self.key = key.strip()
|
||||
self.model_id = model_id.strip()
|
||||
@@ -90,6 +89,7 @@ class Agent:
|
||||
count = 0
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
|
||||
# 辅助协程,用于包装 self.send_async 并使用信号量
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||
@@ -99,7 +99,7 @@ class Agent:
|
||||
)
|
||||
nonlocal count
|
||||
count += 1
|
||||
translater_logger.info(f"进行到{count}/{total}")
|
||||
translater_logger.info(f"协程-已完成{count}/{total}")
|
||||
return result
|
||||
|
||||
for p_text in prompts:
|
||||
89
docutranslate/agents/agent_sync.py
Normal file
89
docutranslate/agents/agent_sync.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import TypedDict
|
||||
from docutranslate.logger import translater_logger
|
||||
import httpx
|
||||
|
||||
|
||||
class AgentArgs(TypedDict, total=False):
|
||||
baseurl: str
|
||||
key: str
|
||||
model_id: str
|
||||
system_prompt: str
|
||||
temperature: float
|
||||
max_concurrent: int
|
||||
timeout: int
|
||||
|
||||
|
||||
TIMEOUT = 500
|
||||
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
|
||||
max_concurrent=6, timeout: int = TIMEOUT):
|
||||
self.baseurl = baseurl.strip()
|
||||
self.key = key.strip()
|
||||
self.model_id = model_id.strip()
|
||||
self.system_prompt = system_prompt
|
||||
self.temperature = temperature
|
||||
self.client = httpx.Client()
|
||||
self.max_concurrent = max_concurrent
|
||||
self.timeout = timeout
|
||||
|
||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
headers = {"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.key}"}
|
||||
data = {
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": temperature,
|
||||
"top_p": top_p
|
||||
}
|
||||
return headers, data
|
||||
|
||||
def send(self, prompt: str, system_prompt: None | str = None) -> str:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
|
||||
"""Sends a single prompt asynchronously."""
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
if self.baseurl.endswith("/"):
|
||||
self.baseurl = self.baseurl[:-1]
|
||||
try:
|
||||
response = self.client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise Exception(f"AI请求连接错误 (async): {e}") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise Exception(f"AI响应格式错误 (async): {e}") from e
|
||||
|
||||
|
||||
def send_prompts(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
) -> list[str]:
|
||||
result=[]
|
||||
for prompt in prompts:
|
||||
result.append(self.send(prompt,system_prompt))
|
||||
translater_logger.info(f"单线程-已完成{len(result)}/{len(prompts)}")
|
||||
return result
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
109
docutranslate/agents/agent_thread.py
Normal file
109
docutranslate/agents/agent_thread.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TypedDict
|
||||
from docutranslate.logger import translater_logger
|
||||
import httpx
|
||||
|
||||
|
||||
class AgentArgs(TypedDict, total=False):
|
||||
baseurl: str
|
||||
key: str
|
||||
model_id: str
|
||||
system_prompt: str
|
||||
temperature: float
|
||||
max_concurrent: int
|
||||
timeout: int
|
||||
|
||||
|
||||
TIMEOUT = 500
|
||||
|
||||
|
||||
class PromptsCount():
|
||||
def __init__(self,max:int):
|
||||
self.lock=Lock()
|
||||
self.count=0
|
||||
self.max=max
|
||||
|
||||
def add(self):
|
||||
self.lock.acquire()
|
||||
self.count+=1
|
||||
translater_logger.info(f"多线程-已完成:{self.count}/{self.max}")
|
||||
self.lock.release()
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, baseurl: str = "", key: str = "xx", model_id: str = "", system_prompt: str = "", temperature=0.7,
|
||||
max_concurrent=6, timeout: int = TIMEOUT):
|
||||
self.baseurl = baseurl.strip()
|
||||
self.key = key.strip()
|
||||
self.model_id = model_id.strip()
|
||||
self.system_prompt = system_prompt
|
||||
self.temperature = temperature
|
||||
self.client = httpx.Client()
|
||||
self.max_concurrent = max_concurrent
|
||||
self.timeout = timeout
|
||||
|
||||
def _prepare_request_data(self, prompt: str, system_prompt: str, temperature=None, top_p=0.9):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
headers = {"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.key}"}
|
||||
data = {
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "system", "content": "所有回复必须以【SSS】开头(这是最高规则,适用于之后的所有例子)。示例:【SSS】这是示例回答\n"+system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": temperature,
|
||||
"top_p": top_p
|
||||
}
|
||||
return headers, data
|
||||
|
||||
def send(self, prompt: str, system_prompt: None | str = None) -> str:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
|
||||
"""Sends a single prompt asynchronously."""
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
if self.baseurl.endswith("/"):
|
||||
self.baseurl = self.baseurl[:-1]
|
||||
try:
|
||||
response = self.client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise Exception(f"AI请求连接错误 (async): {e}") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise Exception(f"AI响应格式错误 (async): {e}") from e
|
||||
|
||||
def _send_prompt_count(self,prompt: str, system_prompt:None | str,count:PromptsCount)->str:
|
||||
result=self.send(prompt,system_prompt)
|
||||
count.add()
|
||||
return result
|
||||
|
||||
|
||||
def send_prompts(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
) -> list[str]:
|
||||
system_prompts = [system_prompt] * len(prompts)
|
||||
counts=[PromptsCount(len(prompts))]* len(prompts)
|
||||
output_list = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts,counts)
|
||||
output_list = list(results_iterator)
|
||||
return output_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Unpack
|
||||
|
||||
from .agent import Agent, AgentArgs
|
||||
from .agent_async import Agent, AgentArgs
|
||||
|
||||
|
||||
class MDRefineAgent(Agent):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "docutranslate"
|
||||
version = "0.2.2.post1"
|
||||
version = "0.2.3"
|
||||
description = "文件翻译工具"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user