修复agent连接池内存泄漏的问题
This commit is contained in:
@@ -66,7 +66,6 @@ class PromptsCounter:
|
|||||||
self.lock.release()
|
self.lock.release()
|
||||||
|
|
||||||
|
|
||||||
TIMEOUT = 600
|
|
||||||
|
|
||||||
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
ResultHandlerType = Callable[[str, str, logging.Logger], str]
|
||||||
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
ErrorResultHandlerType = Callable[[str, logging.Logger], str]
|
||||||
@@ -92,13 +91,6 @@ class Agent:
|
|||||||
self.model_id = config.model_id.strip()
|
self.model_id = config.model_id.strip()
|
||||||
self.system_prompt = config.system_prompt or ""
|
self.system_prompt = config.system_prompt or ""
|
||||||
self.temperature = config.temperature
|
self.temperature = config.temperature
|
||||||
if USE_PROXY:
|
|
||||||
self.client = httpx.Client(proxies=get_httpx_proxies(), verify=False)
|
|
||||||
self.client_async = httpx.AsyncClient(proxies=get_httpx_proxies(), verify=False)
|
|
||||||
else:
|
|
||||||
self.client = httpx.Client(trust_env=False, proxy=None, verify=False)
|
|
||||||
self.client_async = httpx.AsyncClient(trust_env=False, proxy=None, verify=False)
|
|
||||||
|
|
||||||
self.max_concurrent = config.max_concurrent
|
self.max_concurrent = config.max_concurrent
|
||||||
self.timeout = config.timeout
|
self.timeout = config.timeout
|
||||||
self.thinking = config.thinking
|
self.thinking = config.thinking
|
||||||
@@ -133,17 +125,18 @@ class Agent:
|
|||||||
self._add_thinking_mode(data)
|
self._add_thinking_mode(data)
|
||||||
return headers, data
|
return headers, data
|
||||||
|
|
||||||
async def send_async(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
async def send_async(self, client: httpx.AsyncClient, prompt: str, system_prompt: None | str = None, retry=True,
|
||||||
|
retry_count=0,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
error_result_handler: ErrorResultHandlerType = None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if prompt.strip() == "":
|
# if prompt.strip() == "":
|
||||||
return prompt
|
# return prompt
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.client_async.post(
|
response = await client.post(
|
||||||
f"{self.baseurl}/chat/completions",
|
f"{self.baseurl}/chat/completions",
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -167,7 +160,7 @@ class Agent:
|
|||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
return await self.send_async(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
return await self.send_async(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||||
result_handler=result_handler)
|
result_handler=result_handler)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"达到重试次数上限")
|
self.logger.error(f"达到重试次数上限")
|
||||||
@@ -189,10 +182,14 @@ class Agent:
|
|||||||
semaphore = asyncio.Semaphore(max_concurrent)
|
semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
|
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||||
|
|
||||||
# 辅助协程,用于包装 self.send_async 并使用信号量
|
# 辅助协程,用于包装 self.send_async 并使用信号量
|
||||||
|
async with httpx.AsyncClient(trust_env=False, proxies=proxies, verify=False) as client:
|
||||||
async def send_with_semaphore(p_text: str):
|
async def send_with_semaphore(p_text: str):
|
||||||
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
async with semaphore: # 在进入代码块前获取信号量,退出时释放
|
||||||
result = await self.send_async(
|
result = await self.send_async(
|
||||||
|
client=client,
|
||||||
prompt=p_text,
|
prompt=p_text,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
@@ -210,15 +207,15 @@ class Agent:
|
|||||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def send(self, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
def send(self, client: httpx.Client, prompt: str, system_prompt: None | str = None, retry=True, retry_count=0,
|
||||||
result_handler=None, error_result_handler=None) -> Any:
|
result_handler=None, error_result_handler=None) -> Any:
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.system_prompt
|
system_prompt = self.system_prompt
|
||||||
if prompt.strip() == "":
|
# if prompt.strip() == "":
|
||||||
return prompt
|
# return prompt
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||||
try:
|
try:
|
||||||
response = self.client.post(
|
response = client.post(
|
||||||
f"{self.baseurl}/chat/completions",
|
f"{self.baseurl}/chat/completions",
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -228,7 +225,7 @@ class Agent:
|
|||||||
result = response.json()["choices"][0]["message"]["content"]
|
result = response.json()["choices"][0]["message"]["content"]
|
||||||
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
return result if result_handler is None else result_handler(result, prompt, self.logger)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
self.logger.warning(f"AI请求错误 (async): {e.response.status_code} - {e.response.text}")
|
self.logger.warning(f"AI请求错误 (sync): {e.response.status_code} - {e.response.text}")
|
||||||
print(f"prompt:\n{prompt}")
|
print(f"prompt:\n{prompt}")
|
||||||
self.total_error_counter.add()
|
self.total_error_counter.add()
|
||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
@@ -242,15 +239,16 @@ class Agent:
|
|||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
self.logger.info(f"正在重试,重试次数{retry_count}")
|
self.logger.info(f"正在重试,重试次数{retry_count}")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
return self.send(prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
return self.send(client, prompt, system_prompt, retry=True, retry_count=retry_count + 1,
|
||||||
result_handler=result_handler)
|
result_handler=result_handler)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"达到重试次数上限")
|
self.logger.error(f"达到重试次数上限")
|
||||||
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
return prompt if error_result_handler is None else error_result_handler(prompt, self.logger)
|
||||||
|
|
||||||
def _send_prompt_count(self, prompt: str, system_prompt: None | str, count: PromptsCounter, result_handler,
|
def _send_prompt_count(self, client: httpx.Client, prompt: str, system_prompt: None | str, count: PromptsCounter,
|
||||||
|
result_handler,
|
||||||
error_result_handler) -> Any:
|
error_result_handler) -> Any:
|
||||||
result = self.send(prompt, system_prompt, result_handler=result_handler,
|
result = self.send(client, prompt, system_prompt, result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler)
|
error_result_handler=error_result_handler)
|
||||||
count.add()
|
count.add()
|
||||||
return result
|
return result
|
||||||
@@ -274,8 +272,13 @@ class Agent:
|
|||||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||||
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
error_result_handlers = itertools.repeat(error_result_handler, len(prompts))
|
||||||
output_list = []
|
output_list = []
|
||||||
|
proxies = get_httpx_proxies() if USE_PROXY else None
|
||||||
|
with httpx.Client(trust_env=False, proxies=proxies, verify=False) as client:
|
||||||
|
clients = itertools.repeat(client, len(prompts))
|
||||||
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
|
||||||
results_iterator = executor.map(self._send_prompt_count, prompts, system_prompts, counters, result_handlers,error_result_handlers)
|
results_iterator = executor.map(self._send_prompt_count, clients, prompts, system_prompts, counters,
|
||||||
|
result_handlers,
|
||||||
|
error_result_handlers)
|
||||||
output_list = list(results_iterator)
|
output_list = list(results_iterator)
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user