修复agent连接池内存泄漏的问题

This commit is contained in:
xunbu
2025-08-22 10:06:48 +08:00
parent 140dd438d0
commit 48b0f30261

View File

@@ -66,7 +66,6 @@ class PromptsCounter:
self.lock.release()
TIMEOUT = 600
ResultHandlerType = Callable[[str, str, logging.Logger], str]
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
@@ -92,13 +91,6 @@ class Agent:
self.model_id = config.model_id.strip()
self.system_prompt = config.system_prompt or ""
self.temperature = config.temperature
if USE_PROXY:
self.client = httpx.Client(proxies=get_httpx_proxies(), verify=False)
self.client_async = httpx.AsyncClient(proxies=get_httpx_proxies(), verify=False)
else:
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
self.max_concurrent = config.max_concurrent
self.timeout = config.timeout
self.thinking = config.thinking
@@ -133,17 +125,18 @@ class Agent:
self._add_thinking_mode(data)
return headers, data
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
retry_count=0,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if prompt.strip() == "":
return prompt
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
try:
response = await self.client_async.post(
response = await client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
@@ -167,7 +160,7 @@ class Agent:
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
self.logger.info(f"正在重试,重试次数{retry_count}")
await asyncio.sleep(0.5)
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler)
else:
self.logger.error(f"达到重试次数上限")
@@ -189,36 +182,40 @@ class Agent:
semaphore = asyncio.Semaphore(max_concurrent)
tasks = []
proxies = get_httpx_proxies() if USE_PROXY else None
# 辅助协程,用于包装 self.send_async 并使用信号量
async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放
result = await self.send_async(
prompt=p_text,
system_prompt=system_prompt,
result_handler=result_handler,
error_result_handler=error_result_handler,
)
nonlocal count
count += 1
self.logger.info(f"协程-已完成{count}/{total}")
return result
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client:
async def send_with_semaphore(p_text: str):
async with semaphore: # 在进入代码块前获取信号量,退出时释放
result = await self.send_async(
client=client,
prompt=p_text,
system_prompt=system_prompt,
result_handler=result_handler,
error_result_handler=error_result_handler,
)
nonlocal count
count += 1
self.logger.info(f"协程-已完成{count}/{total}")
return result
for p_text in prompts:
task = asyncio.create_task(send_with_semaphore(p_text))
tasks.append(task)
for p_text in prompts:
task = asyncio.create_task(send_with_semaphore(p_text))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=False)
return results
results = await asyncio.gather(*tasks, return_exceptions=False)
return results
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
result_handler=None, error_result_handler=None) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
if prompt.strip() == "":
return prompt
# if prompt.strip() == "":
# return prompt
headers, data = self._prepare_request_data(prompt, system_prompt)
try:
response = self.client.post(
response = client.post(
f"{self.baseurl}/chat/completions",
json=data,
headers=headers,
@@ -228,7 +225,7 @@ class Agent:
result = response.json()["choices"][0]["message"]["content"]
return result if result_handler is None else result_handler(result, prompt, self.logger)
except httpx.HTTPStatusError as e:
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}")
print(f"prompt:\n{prompt}")
self.total_error_counter.add()
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
@@ -242,15 +239,16 @@ class Agent:
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
self.logger.info(f"正在重试,重试次数{retry_count}")
time.sleep(0.5)
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
result_handler=result_handler)
else:
self.logger.error(f"达到重试次数上限")
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler,
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
result_handler,
error_result_handler) -> Any:
result = self.send(prompt, system_prompt, result_handler=result_handler,
result = self.send(client, prompt, system_prompt, result_handler=result_handler,
error_result_handler=error_result_handler)
count.add()
return result
@@ -274,9 +272,14 @@ class Agent:
result_handlers = itertools.repeat(result_handler, len(prompts))
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
output_list = []
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
output_list = list(results_iterator)
proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
result_handlers,
error_result_handlers)
output_list = list(results_iterator)
return output_list