检查agent的result是否满足segments要求
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Literal
|
||||
from typing import Literal, Callable, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@@ -67,6 +68,9 @@ class PromptsCounter:
|
||||
|
||||
TIMEOUT = 600
|
||||
|
||||
# Success post-processor: (raw_result, original_prompt, logger) -> processed value.
# The return type is Any because send/send_async return the handler's value
# unchanged and a handler may produce parsed objects rather than text.
ResultHandlerType = Callable[[str, str, logging.Logger], Any]
# Failure fallback: (original_prompt, logger) -> substitute value returned
# to the caller in place of a model result.
ErrorResultHandlerType = Callable[[str, logging.Logger], Any]
|
||||
|
||||
|
||||
class Agent:
|
||||
_think_factory = {
|
||||
@@ -129,7 +133,9 @@ class Agent:
|
||||
self._add_thinking_mode(data)
|
||||
return headers, data
|
||||
|
||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if prompt.strip() == "":
|
||||
@@ -145,12 +151,12 @@ class Agent:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
return result
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
print(f"prompt:\n{prompt}")
|
||||
self.total_error_counter.add()
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
except httpx.RequestError as e:
|
||||
self.logger.warning(f"AI请求连接错误 (async): {repr(e)}")
|
||||
except (KeyError, IndexError) as e:
|
||||
@@ -158,20 +164,23 @@ class Agent:
|
||||
# 如果没有正常获取结果则重试
|
||||
if retry and retry_count < MAX_RETRY_COUNT:
|
||||
if self.total_error_counter.add():
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
await asyncio.sleep(0.5)
|
||||
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
|
||||
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
|
||||
async def send_prompts_async(
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None # 新增参数,默认并发数为5
|
||||
) -> list[str]:
|
||||
max_concurrent: int | None = None, # 新增参数,默认并发数为5
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
) -> list[Any]:
|
||||
max_concurrent = self.max_concurrent if max_concurrent is None else max_concurrent
|
||||
total = len(prompts)
|
||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||
@@ -186,6 +195,8 @@ class Agent:
|
||||
result = await self.send_async(
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
)
|
||||
nonlocal count
|
||||
count += 1
|
||||
@@ -199,7 +210,8 @@ class Agent:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
return results
|
||||
|
||||
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0) -> str:
|
||||
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
result_handler=None, error_result_handler=None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if prompt.strip() == "":
|
||||
@@ -214,12 +226,12 @@ class Agent:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
return result
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
print(f"prompt:\n{prompt}")
|
||||
self.total_error_counter.add()
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
except httpx.RequestError as e:
|
||||
self.logger.warning(f"AI请求连接错误 (sync): {repr(e)}\nprompt:{prompt}")
|
||||
except (KeyError, IndexError) as e:
|
||||
@@ -227,16 +239,19 @@ class Agent:
|
||||
# 如果没有正常获取结果则重试
|
||||
if retry and retry_count < MAX_RETRY_COUNT:
|
||||
if self.total_error_counter.add():
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
time.sleep(0.5)
|
||||
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1)
|
||||
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
return prompt
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
|
||||
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter) -> str:
|
||||
result = self.send(prompt, system_prompt)
|
||||
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter,
                       result_handler=None, error_result_handler=None) -> Any:
    """Send one prompt synchronously and record its completion on ``count``.

    Args:
        prompt: Text passed to ``send``.
        system_prompt: System prompt passed to ``send``.
        count: Counter whose ``add()`` is called once after ``send`` returns.
        result_handler: Optional success post-processor forwarded to ``send``.
        error_result_handler: Optional failure fallback forwarded to ``send``.

    Returns:
        Whatever ``send`` returns for this prompt.
    """
    # Both handlers default to None (matching ``send``) so the earlier
    # three-argument call form keeps working.
    result = self.send(prompt, system_prompt, result_handler=result_handler,
                       error_result_handler=error_result_handler)
    count.add()
    return result
|
||||
|
||||
@@ -244,14 +259,23 @@ class Agent:
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
) -> list[str]:
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None
|
||||
) -> list[Any]:
|
||||
self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}")
|
||||
self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}")
|
||||
system_prompts = [system_prompt] * len(prompts)
|
||||
counts = [PromptsCounter(len(prompts), self.logger)] * len(prompts)
|
||||
|
||||
# 创建单个计数器实例
|
||||
counter = PromptsCounter(len(prompts), self.logger)
|
||||
|
||||
# 使用 itertools.repeat 将同一个实例传递给每个 map 调用
|
||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||
counters = itertools.repeat(counter, len(prompts))
|
||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||
output_list = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counts)
|
||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
|
||||
output_list = list(results_iterator)
|
||||
return output_list
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
from logging import Logger
|
||||
|
||||
from json_repair import json_repair
|
||||
|
||||
@@ -43,23 +44,38 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
if config.custom_prompt:
|
||||
self.system_prompt += "\n# 重要规则或背景【非常重要】\n" + config.custom_prompt + '\n'
|
||||
|
||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
    """Parse a model reply into a JSON object, falling back to the prompt.

    Args:
        result: Raw text returned by the model.
        origin_prompt: The prompt that produced ``result``; used for the fallback.
        logger: Logger that receives the parse-failure message.

    Returns:
        The parsed dict when ``result`` decodes to a JSON object; otherwise the
        value of ``_error_result_handler(origin_prompt, logger)``.
    """
    try:
        parsed = json_repair.loads(result)
    except Exception:
        # Only ordinary errors are treated as a parse failure; KeyboardInterrupt
        # and SystemExit are allowed to propagate.
        parsed = None
    # The segment senders in this file iterate ``chunk.items()``, so anything
    # that is not a dict is unusable and is handled like a parse failure.
    # NOTE(review): json_repair appears to return non-object values (e.g. an
    # empty string) for badly broken input instead of raising — confirm.
    if not isinstance(parsed, dict):
        logger.error("结果不能正确解析")
        return self._error_result_handler(origin_prompt, logger)
    return parsed
|
||||
|
||||
def _error_result_handler(self, origin_prompt: str, logger: Logger):
    """Build the fallback value for a failed request from the original prompt.

    Args:
        origin_prompt: The prompt text that was sent.
        logger: Logger that receives the message when the prompt cannot be parsed.

    Returns:
        ``json_repair.loads(origin_prompt)`` when parsing succeeds; the
        unmodified ``origin_prompt`` string when parsing raises.
    """
    try:
        return json_repair.loads(origin_prompt)
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt / SystemExit
        # are not swallowed here.
        logger.error("prompt不是json格式")
        return origin_prompt
|
||||
|
||||
def send_segments(self, segments: list[str], chunk_size: int):
|
||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = super().send_prompts(prompts=prompts)
|
||||
translated_chunks = super().send_prompts(prompts=prompts, result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk_str in translated_chunks:
|
||||
for chunk in translated_chunks:
|
||||
try:
|
||||
translated_part = json_repair.loads(chunk_str)
|
||||
for key, val in translated_part.items():
|
||||
for key, val in chunk.items():
|
||||
if key in indexed_translated:
|
||||
indexed_translated[key] = val
|
||||
except JSONDecodeError as e:
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||
except ValueError as e:
|
||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||
except Exception as e:
|
||||
self.logger.info(f"send_segments错误:{e.__repr__()}")
|
||||
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||
|
||||
# 初始化结果列表
|
||||
result = []
|
||||
@@ -79,28 +95,28 @@ Warning: Never wrap the entire JSON object in quotes to make it a single string.
|
||||
|
||||
# todo:增加协程粒度
|
||||
async def send_segments_async(self, segments: list[str], chunk_size: int):
|
||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks,segments, chunk_size)
|
||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||
chunk_size)
|
||||
prompts = [json.dumps(chunk, ensure_ascii=False) for chunk in chunks]
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts)
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts, result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
indexed_translated = indexed_originals.copy()
|
||||
for chunk_str in translated_chunks:
|
||||
for chunk in translated_chunks:
|
||||
try:
|
||||
translated_part = json_repair.loads(chunk_str)
|
||||
for key, val in translated_part.items():
|
||||
for key, val in chunk.items():
|
||||
if key in indexed_translated:
|
||||
indexed_translated[key] = val
|
||||
except JSONDecodeError as e:
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk_str},错误:{e.__repr__()}")
|
||||
self.logger.info(f"json解析错误,解析文本:{chunk},错误:{e.__repr__()}")
|
||||
except ValueError as e:
|
||||
self.logger.info(f"value错误,更新对象:{indexed_translated},错误:{e.__repr__()}")
|
||||
except Exception as e:
|
||||
self.logger.info(f"send_segments错误:{e.__repr__()}")
|
||||
|
||||
self.logger.info(f"send_segments发生错误:{e.__repr__()}")
|
||||
|
||||
# 初始化结果列表
|
||||
result = []
|
||||
last_end = 0
|
||||
ls=list(indexed_translated.values())
|
||||
ls = list(indexed_translated.values())
|
||||
for start, end in merged_indices_list:
|
||||
# 添加未处理的部分
|
||||
result.extend(ls[last_end:start])
|
||||
|
||||
Reference in New Issue
Block a user