提供json_format选项

This commit is contained in:
xunbu
2025-11-09 14:31:57 +08:00
parent 66015c23b2
commit 69028b2e7f
3 changed files with 32 additions and 12 deletions

View File

@@ -234,8 +234,9 @@ class Agent:
elif self.thinking == "disable": elif self.thinking == "disable":
data[field_thinking] = val_disable data[field_thinking] = val_disable
def _prepare_request_data( def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False
): ):
if temperature is None: if temperature is None:
temperature = self.temperature temperature = self.temperature
@@ -254,6 +255,8 @@ class Agent:
} }
if self.thinking != "default": if self.thinking != "default":
self._add_thinking_mode(data) self._add_thinking_mode(data)
if json_format:
data["response_format"] = {"type": "json_object"}
return headers, data return headers, data
async def send_async( async def send_async(
@@ -263,6 +266,7 @@ class Agent:
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -274,7 +278,7 @@ class Agent:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# print(f"system_prompt:\n{system_prompt}") # print(f"system_prompt:\n{system_prompt}")
# print(f"【测试】prompt:\n{prompt}") # print(f"【测试】prompt:\n{prompt}")
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误 is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None current_partial_result = None
@@ -412,6 +416,7 @@ class Agent:
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None, max_concurrent: int | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -454,6 +459,7 @@ class Agent:
client=client, client=client,
prompt=p_text, prompt=p_text,
system_prompt=system_prompt, system_prompt=system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler, pre_send_handler=pre_send_handler,
result_handler=result_handler, result_handler=result_handler,
error_result_handler=error_result_handler, error_result_handler=error_result_handler,
@@ -494,6 +500,7 @@ class Agent:
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
json_format=False,
pre_send_handler=None, pre_send_handler=None,
result_handler=None, result_handler=None,
error_result_handler=None, error_result_handler=None,
@@ -504,7 +511,7 @@ class Agent:
if pre_send_handler: if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误 is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None current_partial_result = None
@@ -638,15 +645,17 @@ class Agent:
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str, system_prompt: None | str,
json_format,
count: PromptsCounter, count: PromptsCounter,
pre_send_handler, pre_send_handler,
result_handler, result_handler,
error_result_handler, error_result_handler
) -> Any: ) -> Any:
result = self.send( result = self.send(
client, client,
prompt, prompt,
system_prompt, system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler, pre_send_handler=pre_send_handler,
result_handler=result_handler, result_handler=result_handler,
error_result_handler=error_result_handler, error_result_handler=error_result_handler,
@@ -658,6 +667,7 @@ class Agent:
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -680,6 +690,7 @@ class Agent:
counter = PromptsCounter(len(prompts), self.logger) counter = PromptsCounter(len(prompts), self.logger)
system_prompts = itertools.repeat(system_prompt, len(prompts)) system_prompts = itertools.repeat(system_prompt, len(prompts))
json_formats = itertools.repeat(json_format, len(prompts))
counters = itertools.repeat(counter, len(prompts)) counters = itertools.repeat(counter, len(prompts))
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts)) pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts)) result_handlers = itertools.repeat(result_handler, len(prompts))
@@ -699,6 +710,7 @@ class Agent:
clients, clients,
prompts, prompts,
system_prompts, system_prompts,
json_formats,
counters, counters,
pre_send_handlers, pre_send_handlers,
result_handlers, result_handlers,

View File

@@ -57,6 +57,7 @@ Output(target language: {to_lang}):
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are. Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
""" """
def get_original_segments(prompt: str): def get_original_segments(prompt: str):
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL) match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
if match: if match:
@@ -64,6 +65,7 @@ def get_original_segments(prompt:str):
else: else:
raise ValueError("无法从prompt中提取初始文本") raise ValueError("无法从prompt中提取初始文本")
def get_target_segments(result: str): def get_target_segments(result: str):
match = re.search(r'```json(.*)```', result, re.DOTALL) match = re.search(r'```json(.*)```', result, re.DOTALL)
if match: if match:
@@ -71,17 +73,20 @@ def get_target_segments(result:str):
else: else:
return result return result
@dataclass @dataclass
class SegmentsTranslateAgentConfig(AgentConfig): class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str to_lang: str
custom_prompt: str | None = None custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None glossary_dict: dict[str, str] | None = None
json_format:bool = True
class SegmentsTranslateAgent(Agent): class SegmentsTranslateAgent(Agent):
def __init__(self, config: SegmentsTranslateAgentConfig): def __init__(self, config: SegmentsTranslateAgentConfig):
super().__init__(config) super().__init__(config)
self.to_lang = config.to_lang self.to_lang = config.to_lang
self.json_format = config.json_format
self.system_prompt = f""" self.system_prompt = f"""
# Role # Role
- You are a professional, authentic machine translation engine. - You are a professional, authentic machine translation engine.
@@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
@@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent):
chunk_size) chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)

View File

@@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
}, },
), ),
"siliconflow": ("enable_thinking", True, False), "siliconflow": ("enable_thinking", True, False),
"default":("reasoning_effort","medium","minimal"),
} }
@@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
return thinking_mode["volces"] return thinking_mode["volces"]
elif "gemini" in model_id: elif "gemini" in model_id:
return thinking_mode["google"] return thinking_mode["google"]
return None return thinking_mode["default"]
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None: def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
@@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
return thinking_mode["siliconflow"] return thinking_mode["siliconflow"]
elif provider == "api.302.ai": elif provider == "api.302.ai":
return get_thinking_mode_by_model_id(model_id) return get_thinking_mode_by_model_id(model_id)
return None return thinking_mode["default"]
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool): # def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):