diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 1d59082..5059ae7 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -15,7 +15,7 @@ from docutranslate.logger import global_logger from docutranslate.utils.utils import get_httpx_proxies MAX_RETRY_COUNT = 2 -MAX_TOTAL_ERROR_COUNT = 10 +MAX_REQUESTS_PER_ERROR = 30 ThinkingMode = Literal["enable", "disable", "default"] @@ -34,21 +34,22 @@ class AgentConfig: class TotalErrorCounter: - def __init__(self, logger: logging.Logger): + def __init__(self, logger: logging.Logger,max_errors_count=10): self.lock = Lock() self.count = 0 self.logger = logger + self.max_errors_count=max_errors_count def add(self): self.lock.acquire() self.count += 1 - if self.count > MAX_TOTAL_ERROR_COUNT: + if self.count > self.max_errors_count: self.logger.info(f"错误响应过多") self.lock.release() return self.reach_limit() def reach_limit(self): - return self.count > MAX_TOTAL_ERROR_COUNT + return self.count > self.max_errors_count # 仅使用多线程时用以计数 @@ -178,6 +179,7 @@ class Agent: total = len(prompts) self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") + self.total_error_counter.max_errors_count=len(prompts) // MAX_REQUESTS_PER_ERROR #允许多少个异常 count = 0 semaphore = asyncio.Semaphore(max_concurrent) tasks = [] @@ -262,7 +264,7 @@ class Agent: ) -> list[Any]: self.logger.info(f"base-url:{self.baseurl},model-id:{self.model_id}") self.logger.info(f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}") - + self.total_error_counter.max_errors_count = len(prompts) // MAX_REQUESTS_PER_ERROR # 允许多少个异常 # 创建单个计数器实例 counter = PromptsCounter(len(prompts), self.logger)