From 69028b2e7f1f0f803785fdf84e1983d9800be281 Mon Sep 17 00:00:00 2001 From: xunbu Date: Sun, 9 Nov 2025 14:31:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BE=9Bjson=5Fformat=E9=80=89?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docutranslate/agents/agent.py | 20 +++++++++++++++---- docutranslate/agents/segments_agent.py | 19 ++++++++++++------ .../agents/thinking/thinking_factory.py | 5 +++-- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index 31d9ec5..47d7966 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -234,8 +234,9 @@ class Agent: elif self.thinking == "disable": data[field_thinking] = val_disable + def _prepare_request_data( - self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 + self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False ): if temperature is None: temperature = self.temperature @@ -254,6 +255,8 @@ class Agent: } if self.thinking != "default": self._add_thinking_mode(data) + if json_format: + data["response_format"] = {"type": "json_object"} return headers, data async def send_async( @@ -263,6 +266,7 @@ class Agent: system_prompt: None | str = None, retry=True, retry_count=0, + json_format=False, pre_send_handler: PreSendHandlerType = None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None, @@ -274,7 +278,7 @@ class Agent: system_prompt, prompt = pre_send_handler(system_prompt, prompt) # print(f"system_prompt:\n{system_prompt}") # print(f"【测试】prompt:\n{prompt}") - headers, data = self._prepare_request_data(prompt, system_prompt) + headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format) should_retry = False is_hard_error = False # 新增标志,用于区分是否为硬错误 current_partial_result = None @@ -412,6 +416,7 @@ class Agent: prompts: list[str], system_prompt: str | None = None, max_concurrent: int | None = None, + json_format=False, pre_send_handler: PreSendHandlerType = None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None, @@ -454,6 +459,7 @@ class Agent: client=client, prompt=p_text, system_prompt=system_prompt, + json_format=json_format, pre_send_handler=pre_send_handler, result_handler=result_handler, error_result_handler=error_result_handler, @@ -494,6 +500,7 @@ class Agent: system_prompt: None | str = None, retry=True, retry_count=0, + json_format=False, pre_send_handler=None, result_handler=None, error_result_handler=None, @@ -504,7 +511,7 @@ class Agent: if pre_send_handler: system_prompt, prompt = pre_send_handler(system_prompt, prompt) - headers, data = self._prepare_request_data(prompt, system_prompt) + headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format) should_retry = False is_hard_error = False # 新增标志,用于区分是否为硬错误 current_partial_result = None @@ -638,15 +645,17 @@ class Agent: client: httpx.Client, prompt: str, system_prompt: None | str, + json_format, count: PromptsCounter, pre_send_handler, result_handler, - error_result_handler, + error_result_handler ) -> Any: result = self.send( client, prompt, system_prompt, + json_format=json_format, pre_send_handler=pre_send_handler, result_handler=result_handler, error_result_handler=error_result_handler, @@ -658,6 +667,7 @@ class Agent: self, prompts: list[str], system_prompt: str | None = None, + json_format=False, pre_send_handler: PreSendHandlerType = None, result_handler: ResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None, @@ -680,6 +690,7 @@ class Agent: counter = PromptsCounter(len(prompts), self.logger) system_prompts = itertools.repeat(system_prompt, len(prompts)) + json_formats = itertools.repeat(json_format, len(prompts)) counters = itertools.repeat(counter, len(prompts)) pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts)) result_handlers = itertools.repeat(result_handler, len(prompts)) @@ -699,6 +710,7 @@ class Agent: clients, prompts, system_prompts, + json_formats, counters, pre_send_handlers, result_handlers, diff --git a/docutranslate/agents/segments_agent.py b/docutranslate/agents/segments_agent.py index 351b3d0..17c4145 100644 --- a/docutranslate/agents/segments_agent.py +++ b/docutranslate/agents/segments_agent.py @@ -57,31 +57,36 @@ Output(target language: {to_lang}): Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are. """ -def get_original_segments(prompt:str): + +def get_original_segments(prompt: str): match = re.search(r'\n```json\n(.*)\n```\n', prompt, re.DOTALL) if match: return match.group(1) else: raise ValueError("无法从prompt中提取初始文本") -def get_target_segments(result:str): + +def get_target_segments(result: str): match = re.search(r'```json(.*)```', result, re.DOTALL) if match: return match.group(1) else: return result + @dataclass class SegmentsTranslateAgentConfig(AgentConfig): to_lang: str custom_prompt: str | None = None glossary_dict: dict[str, str] | None = None + json_format:bool = True class SegmentsTranslateAgent(Agent): def __init__(self, config: SegmentsTranslateAgentConfig): super().__init__(config) self.to_lang = config.to_lang + self.json_format = config.json_format self.system_prompt = f""" # Role - You are a professional, authentic machine translation engine. @@ -104,7 +109,7 @@ class SegmentsTranslateAgent(Agent): - 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。 - 其他错误(如JSON解析失败、模型偷懒)则抛出普通 ValueError 触发重试。 """ - original_segments=get_original_segments(origin_prompt) + original_segments = get_original_segments(origin_prompt) result = get_target_segments(result) if result == "": if original_segments.strip() != "": @@ -160,7 +165,7 @@ class SegmentsTranslateAgent(Agent): 处理在所有重试后仍然失败的请求。 作为备用方案,返回原文内容,并将所有值转换为字符串。 """ - original_segments=get_original_segments(origin_prompt) + original_segments = get_original_segments(origin_prompt) if original_segments == "": return {} try: @@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent): indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] - translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, + translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format, + pre_send_handler=self._pre_send_handler, result_handler=self._result_handler, error_result_handler=self._error_result_handler) @@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent): chunk_size) prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] - translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, + translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format, + pre_send_handler=self._pre_send_handler, result_handler=self._result_handler, error_result_handler=self._error_result_handler) diff --git a/docutranslate/agents/thinking/thinking_factory.py b/docutranslate/agents/thinking/thinking_factory.py index 1b6c61c..a1aff52 100644 --- a/docutranslate/agents/thinking/thinking_factory.py +++ b/docutranslate/agents/thinking/thinking_factory.py @@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]] }, ), "siliconflow": ("enable_thinking", True, False), + "default":("reasoning_effort","medium","minimal"), } @@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str | return thinking_mode["volces"] elif "gemini" in model_id: return thinking_mode["google"] - return None + return thinking_mode["default"] def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None: @@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st return thinking_mode["siliconflow"] elif provider == "api.302.ai": return get_thinking_mode_by_model_id(model_id) - return None + return thinking_mode["default"] # def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):