diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 7411056..288a04b 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -359,6 +359,94 @@ class Agent: data["response_format"] = {"type": "json_object"} return headers, data + async def _continue_fetch_async( + self, + client: httpx.AsyncClient, + prompt: str, + system_prompt: str, + force_json: bool, + pre_send_handler: PreSendHandlerType, + result_handler: ResultHandlerType, + error_result_handler: ErrorResultHandlerType, + retry_count: int, + accumulated_result: str = "", + ) -> Any: + """ + 当 finish_reason 为 length 时,继续获取剩余内容 + """ + self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...") + + # 使用空内容继续请求,实际上多数 API 需要用户提供已获取的内容作为上下文 + # 这里我们发送一个继续信号,让模型继续输出 + continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]" + + if pre_send_handler: + system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt) + + # 速率限制检查 + estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt) + await self.rate_limiter.acquire_async(tokens=estimated_tokens) + + headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json) + + try: + response = await client.post( + f"{self.baseurl}/chat/completions", + json=data, + headers=headers, + timeout=self.timeout, + ) + response.raise_for_status() + response_data = response.json() + + finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None) + additional_result = response_data["choices"][0]["message"]["content"] + + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( + extract_token_info(response_data) + ) + self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens) + + # 累加结果 + accumulated_result += additional_result + + # 如果仍然是 length,继续获取 + if finish_reason == "length": + return await self._continue_fetch_async( + client=client, + prompt=prompt, + system_prompt=system_prompt, + force_json=force_json, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + retry_count=retry_count, + accumulated_result=accumulated_result, + ) + + # 非 length 结束,返回累加结果 + return ( + accumulated_result + if result_handler is None + else result_handler(accumulated_result, prompt, self.logger) + ) + + except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e: + self.logger.error(f"继续获取内容失败: {repr(e)}") + # 返回已获取的部分结果 + if accumulated_result: + return ( + accumulated_result + if result_handler is None + else result_handler(accumulated_result, prompt, self.logger) + ) + # 如果没有部分结果,调用错误处理器 + return ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + async def send_async( self, client: httpx.AsyncClient, @@ -398,9 +486,31 @@ class Agent: timeout=self.timeout, ) response.raise_for_status() - result = response.json()["choices"][0]["message"]["content"] - response_data = response.json() + + # 检查 finish_reason + finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None) + if finish_reason != "stop": + # 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等 + self.logger.warning( + f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..." + ) + + # 如果是长度限制,尝试继续获取 + if finish_reason == "length": + return await self._continue_fetch_async( + client=client, + prompt=prompt, + system_prompt=system_prompt, + force_json=force_json, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + retry_count=retry_count, + ) + + result = response_data["choices"][0]["message"]["content"] + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( extract_token_info(response_data) ) @@ -593,6 +703,86 @@ class Agent: return results + def _continue_fetch( + self, + client: httpx.Client, + prompt: str, + system_prompt: str, + force_json: bool, + pre_send_handler, + result_handler, + error_result_handler, + retry_count: int, + accumulated_result: str = "", + ) -> Any: + """ + 当 finish_reason 为 length 时,继续获取剩余内容(同步版本) + """ + self.logger.info(f"继续获取剩余内容 (已累计 {len(accumulated_result)} 字符)...") + + continue_prompt = f"{prompt}\n\n[系统提示:之前的响应被截断,请继续输出剩余内容。]" + + if pre_send_handler: + system_prompt, continue_prompt = pre_send_handler(system_prompt, continue_prompt) + + estimated_tokens = self._estimate_tokens(system_prompt) + self._estimate_tokens(continue_prompt) + self.rate_limiter.acquire_sync(tokens=estimated_tokens) + + headers, data = self._prepare_request_data(continue_prompt, system_prompt, json_format=force_json) + + try: + response = client.post( + f"{self.baseurl}/chat/completions", + json=data, + headers=headers, + timeout=self.timeout, + ) + response.raise_for_status() + response_data = response.json() + + finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None) + additional_result = response_data["choices"][0]["message"]["content"] + + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( + extract_token_info(response_data) + ) + self.token_counter.add(input_tokens, cached_tokens, output_tokens, reasoning_tokens) + + accumulated_result += additional_result + + if finish_reason == "length": + return self._continue_fetch( + client=client, + prompt=prompt, + system_prompt=system_prompt, + force_json=force_json, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + retry_count=retry_count, + accumulated_result=accumulated_result, + ) + + return ( + accumulated_result + if result_handler is None + else result_handler(accumulated_result, prompt, self.logger) + ) + + except (httpx.HTTPStatusError, httpx.RequestError, KeyError, IndexError, ValueError) as e: + self.logger.error(f"继续获取内容失败: {repr(e)}") + if accumulated_result: + return ( + accumulated_result + if result_handler is None + else result_handler(accumulated_result, prompt, self.logger) + ) + return ( + prompt + if error_result_handler is None + else error_result_handler(prompt, self.logger) + ) + def send( self, client: httpx.Client, @@ -630,10 +820,31 @@ class Agent: timeout=self.timeout, ) response.raise_for_status() - - result = response.json()["choices"][0]["message"]["content"] - response_data = response.json() + + # 检查 finish_reason + finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", None) + if finish_reason != "stop": + # 非正常结束,可能是 length (长度限制)、tool_calls、content_filter 等 + self.logger.warning( + f"finish_reason 为 '{finish_reason}',非正常结束。prompt: {prompt[:100]}..." + ) + + # 如果是长度限制,尝试继续获取 + if finish_reason == "length": + return self._continue_fetch( + client=client, + prompt=prompt, + system_prompt=system_prompt, + force_json=force_json, + pre_send_handler=pre_send_handler, + result_handler=result_handler, + error_result_handler=error_result_handler, + retry_count=retry_count, + ) + + result = response_data["choices"][0]["message"]["content"] + input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( extract_token_info(response_data) )