检查agent的result是否满足segments要求
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Literal
|
from typing import Literal, Callable, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -67,6 +68,9 @@ class PromptsCounter:
|
|||||||
|
|
||||||
TIMEOUT = 600
|
TIMEOUT = 600
|
||||||
|
|
||||||
|
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
||||||
|
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
_think_factory = {
|
_think_factory = {
|
||||||
@@ -129,7 +133,9 @@ class Agent:
|
|||||||
self._add_thinking_mode(data)
|
self._add_thinking_mode(data)
|
||||||
return headers, data
|
return headers, data
|
||||||
|
|
||||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||||
|
result_handler: ResultHandlerType = None,
|
||||||
|
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if prompt.strip() == "":
|
if prompt.strip() == "":
|
||||||
@@ -145,12 +151,12 @@ class Agent:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()["choices"][0]["message"]["content"]
|
result = response.json()["choices"][0]["message"]["content"]
|
||||||
return result
|
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||||
print(f"prompt:\n{prompt}")
|
print(f"prompt:\n{prompt}")
|
||||||
self.total_error_counter.add()
|
self.total_error_counter.add()
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
|
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
|
||||||
except (KeyError, IndexError) as e:
|
except (KeyError, IndexError) as e:
|
||||||
@@ -158,20 +164,23 @@ class Agent:
|
|||||||
# 如果没有正常获取结果则重试
|
# 如果没有正常获取结果则重试
|
||||||
if retry and retry_count < MAX_RETRY_COUNT:
|
if retry and retry_count < MAX_RETRY_COUNT:
|
||||||
if self.total_error_counter.add():
|
if self.total_error_counter.add():
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
|
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||||
|
result_handler=result_handler)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"达到重试次数上限")
|
self.logger.error(f"达到重试次数上限")
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
|
|
||||||
async def send_prompts_async(
|
async def send_prompts_async(
|
||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_concurrent: int | None = None # 新增参数,默认并发数为5
|
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
||||||
) -> list[str]:
|
result_handler: ResultHandlerType = None,
|
||||||
|
error_result_handler: ErrorResultHandlerType = None
|
||||||
|
) -> list[Any]:
|
||||||
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
||||||
total = len(prompts)
|
total = len(prompts)
|
||||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||||
@@ -186,6 +195,8 @@ class Agent:
|
|||||||
result = await self.send_async(
|
result = await self.send_async(
|
||||||
prompt=p_text,
|
prompt=p_text,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
result_handler=result_handler,
|
||||||
|
error_result_handler=error_result_handler,
|
||||||
)
|
)
|
||||||
nonlocal count
|
nonlocal count
|
||||||
count += 1
|
count += 1
|
||||||
@@ -199,7 +210,8 @@ class Agent:
|
|||||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||||
|
result_handler=None, error_result_handler=None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if prompt.strip() == "":
|
if prompt.strip() == "":
|
||||||
@@ -214,12 +226,12 @@ class Agent:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()["choices"][0]["message"]["content"]
|
result = response.json()["choices"][0]["message"]["content"]
|
||||||
return result
|
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||||
print(f"prompt:\n{prompt}")
|
print(f"prompt:\n{prompt}")
|
||||||
self.total_error_counter.add()
|
self.total_error_counter.add()
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||||||
except (KeyError, IndexError) as e:
|
except (KeyError, IndexError) as e:
|
||||||
@@ -227,16 +239,19 @@ class Agent:
|
|||||||
# 如果没有正常获取结果则重试
|
# 如果没有正常获取结果则重试
|
||||||
if retry and retry_count < MAX_RETRY_COUNT:
|
if retry and retry_count < MAX_RETRY_COUNT:
|
||||||
if self.total_error_counter.add():
|
if self.total_error_counter.add():
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
|
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||||
|
result_handler=result_handler)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"达到重试次数上限")
|
self.logger.error(f"达到重试次数上限")
|
||||||
return prompt
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
|
|
||||||
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter) -> str:
|
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler,
|
||||||
result = self.send(prompt, system_prompt)
|
error_result_handler) -> Any:
|
||||||
|
result = self.send(prompt, system_prompt, result_handler=result_handler,
|
||||||
|
error_result_handler=error_result_handler)
|
||||||
count.add()
|
count.add()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -244,14 +259,23 @@ class Agent:
|
|||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
) -> list[str]:
|
result_handler: ResultHandlerType = None,
|
||||||
|
error_result_handler: ErrorResultHandlerType = None
|
||||||
|
) -> list[Any]:
|
||||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||||
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
||||||
system_prompts = [system_prompt] * len(prompts)
|
|
||||||
counts = [PromptsCounter(len(prompts), self.logger)] * len(prompts)
|
# 创建单个计数器实例
|
||||||
|
counter = PromptsCounter(len(prompts), self.logger)
|
||||||
|
|
||||||
|
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
||||||
|
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||||
|
counters = itertools.repeat(counter, len(prompts))
|
||||||
|
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||||
|
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||||
output_list = []
|
output_list = []
|
||||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counts)
|
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
|
||||||
output_list = list(results_iterator)
|
output_list = list(results_iterator)
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
from json_repair import json_repair
|
from json_repair import json_repair
|
||||||
|
|
||||||
@@ -43,23 +44,38 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
|||||||
if config.custom_prompt:
|
if config.custom_prompt:
|
||||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||||
|
|
||||||
|
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
||||||
|
try:
|
||||||
|
result = json_repair.loads(result)
|
||||||
|
except:
|
||||||
|
logger.error("结果不能正确解析")
|
||||||
|
return self._error_result_handler(origin_prompt, logger)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _error_result_handler(self, origin_prompt: str, logger: Logger):
|
||||||
|
try:
|
||||||
|
return json_repair.loads(origin_prompt)
|
||||||
|
except:
|
||||||
|
logger.error("prompt不是json格式")
|
||||||
|
return origin_prompt
|
||||||
|
|
||||||
def send_segments(self, segments: list[str], chunk_size: int):
|
def send_segments(self, segments: list[str], chunk_size: int):
|
||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||||
translated_chunks = super().send_prompts(prompts=prompts)
|
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
|
||||||
|
error_result_handler=self._error_result_handler)
|
||||||
indexed_translated = indexed_originals.copy()
|
indexed_translated = indexed_originals.copy()
|
||||||
for chunk_str in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
try:
|
try:
|
||||||
translated_part = json_repair.loads(chunk_str)
|
for key, val in chunk.items():
|
||||||
for key, val in translated_part.items():
|
|
||||||
if key in indexed_translated:
|
if key in indexed_translated:
|
||||||
indexed_translated[key] = val
|
indexed_translated[key] = val
|
||||||
except JSONDecodeError as e:
|
except JSONDecodeError as e:
|
||||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.info(f"send_segments错误:{e.__repr__()}")
|
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||||
|
|
||||||
# 初始化结果列表
|
# 初始化结果列表
|
||||||
result = []
|
result = []
|
||||||
@@ -79,23 +95,23 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
|||||||
|
|
||||||
# todo:增加协程粒度
|
# todo:增加协程粒度
|
||||||
async def send_segments_async(self, segments: list[str], chunk_size: int):
|
async def send_segments_async(self, segments: list[str], chunk_size: int):
|
||||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks,segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||||
|
chunk_size)
|
||||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||||
translated_chunks = await super().send_prompts_async(prompts=prompts)
|
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
|
||||||
|
error_result_handler=self._error_result_handler)
|
||||||
indexed_translated = indexed_originals.copy()
|
indexed_translated = indexed_originals.copy()
|
||||||
for chunk_str in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
try:
|
try:
|
||||||
translated_part = json_repair.loads(chunk_str)
|
for key, val in chunk.items():
|
||||||
for key, val in translated_part.items():
|
|
||||||
if key in indexed_translated:
|
if key in indexed_translated:
|
||||||
indexed_translated[key] = val
|
indexed_translated[key] = val
|
||||||
except JSONDecodeError as e:
|
except JSONDecodeError as e:
|
||||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.info(f"send_segments错误:{e.__repr__()}")
|
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||||
|
|
||||||
|
|
||||||
# 初始化结果列表
|
# 初始化结果列表
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
Reference in New Issue
Block a user