修复agent连接池内存泄漏的问题
This commit is contained in:
@@ -66,7 +66,6 @@ class PromptsCounter:
|
||||
self.lock.release()
|
||||
|
||||
|
||||
TIMEOUT = 600
|
||||
|
||||
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
||||
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
||||
@@ -92,13 +91,6 @@ class Agent:
|
||||
self.model_id = config.model_id.strip()
|
||||
self.system_prompt = config.system_prompt or ""
|
||||
self.temperature = config.temperature
|
||||
if USE_PROXY:
|
||||
self.client = httpx.Client(proxies=get_httpx_proxies(), verify=False)
|
||||
self.client_async = httpx.AsyncClient(proxies=get_httpx_proxies(), verify=False)
|
||||
else:
|
||||
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
|
||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
||||
|
||||
self.max_concurrent = config.max_concurrent
|
||||
self.timeout = config.timeout
|
||||
self.thinking = config.thinking
|
||||
@@ -133,17 +125,18 @@ class Agent:
|
||||
self._add_thinking_mode(data)
|
||||
return headers, data
|
||||
|
||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
||||
retry_count=0,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if prompt.strip() == "":
|
||||
return prompt
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
|
||||
try:
|
||||
response = await self.client_async.post(
|
||||
response = await client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
@@ -167,7 +160,7 @@ class Agent:
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
await asyncio.sleep(0.5)
|
||||
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
@@ -189,36 +182,40 @@ class Agent:
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
tasks = []
|
||||
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
|
||||
# 辅助协程,用于包装 self.send_async 并使用信号量
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||
result = await self.send_async(
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
)
|
||||
nonlocal count
|
||||
count += 1
|
||||
self.logger.info(f"协程-已完成{count}/{total}")
|
||||
return result
|
||||
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client:
|
||||
async def send_with_semaphore(p_text: str):
|
||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||
result = await self.send_async(
|
||||
client=client,
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
)
|
||||
nonlocal count
|
||||
count += 1
|
||||
self.logger.info(f"协程-已完成{count}/{total}")
|
||||
return result
|
||||
|
||||
for p_text in prompts:
|
||||
task = asyncio.create_task(send_with_semaphore(p_text))
|
||||
tasks.append(task)
|
||||
for p_text in prompts:
|
||||
task = asyncio.create_task(send_with_semaphore(p_text))
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
return results
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
return results
|
||||
|
||||
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||
result_handler=None, error_result_handler=None) -> Any:
|
||||
if system_prompt is None:
|
||||
system_prompt = self.system_prompt
|
||||
if prompt.strip() == "":
|
||||
return prompt
|
||||
# if prompt.strip() == "":
|
||||
# return prompt
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
try:
|
||||
response = self.client.post(
|
||||
response = client.post(
|
||||
f"{self.baseurl}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
@@ -228,7 +225,7 @@ class Agent:
|
||||
result = response.json()["choices"][0]["message"]["content"]
|
||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
||||
self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}")
|
||||
print(f"prompt:\n{prompt}")
|
||||
self.total_error_counter.add()
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
@@ -242,15 +239,16 @@ class Agent:
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||
time.sleep(0.5)
|
||||
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||
result_handler=result_handler)
|
||||
else:
|
||||
self.logger.error(f"达到重试次数上限")
|
||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||
|
||||
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler,
|
||||
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
|
||||
result_handler,
|
||||
error_result_handler) -> Any:
|
||||
result = self.send(prompt, system_prompt, result_handler=result_handler,
|
||||
result = self.send(client, prompt, system_prompt, result_handler=result_handler,
|
||||
error_result_handler=error_result_handler)
|
||||
count.add()
|
||||
return result
|
||||
@@ -274,9 +272,14 @@ class Agent:
|
||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||
output_list = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
|
||||
output_list = list(results_iterator)
|
||||
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||
with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client:
|
||||
clients = itertools.repeat(client, len(prompts))
|
||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
||||
result_handlers,
|
||||
error_result_handlers)
|
||||
output_list = list(results_iterator)
|
||||
return output_list
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user