提供json_format选项

This commit is contained in:
xunbu
2025-11-09 14:31:57 +08:00
parent 66015c23b2
commit 69028b2e7f
3 changed files with 32 additions and 12 deletions

View File

@@ -234,8 +234,9 @@ class Agent:
elif self.thinking == "disable": elif self.thinking == "disable":
data[field_thinking] = val_disable data[field_thinking] = val_disable
def _prepare_request_data( def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False
): ):
if temperature is None: if temperature is None:
temperature = self.temperature temperature = self.temperature
@@ -254,6 +255,8 @@ class Agent:
} }
if self.thinking != "default": if self.thinking != "default":
self._add_thinking_mode(data) self._add_thinking_mode(data)
if json_format:
data["response_format"] = {"type": "json_object"}
return headers, data return headers, data
async def send_async( async def send_async(
@@ -263,6 +266,7 @@ class Agent:
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -274,7 +278,7 @@ class Agent:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# print(f"system_prompt:\n{system_prompt}") # print(f"system_prompt:\n{system_prompt}")
# print(f"【测试】prompt:\n{prompt}") # print(f"【测试】prompt:\n{prompt}")
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误 is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None current_partial_result = None
@@ -412,6 +416,7 @@ class Agent:
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None, max_concurrent: int | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -454,6 +459,7 @@ class Agent:
client=client, client=client,
prompt=p_text, prompt=p_text,
system_prompt=system_prompt, system_prompt=system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler, pre_send_handler=pre_send_handler,
result_handler=result_handler, result_handler=result_handler,
error_result_handler=error_result_handler, error_result_handler=error_result_handler,
@@ -494,6 +500,7 @@ class Agent:
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
json_format=False,
pre_send_handler=None, pre_send_handler=None,
result_handler=None, result_handler=None,
error_result_handler=None, error_result_handler=None,
@@ -504,7 +511,7 @@ class Agent:
if pre_send_handler: if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt) system_prompt, prompt = pre_send_handler(system_prompt, prompt)
headers, data = self._prepare_request_data(prompt, system_prompt) headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误 is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None current_partial_result = None
@@ -638,15 +645,17 @@ class Agent:
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str, system_prompt: None | str,
json_format,
count: PromptsCounter, count: PromptsCounter,
pre_send_handler, pre_send_handler,
result_handler, result_handler,
error_result_handler, error_result_handler
) -> Any: ) -> Any:
result = self.send( result = self.send(
client, client,
prompt, prompt,
system_prompt, system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler, pre_send_handler=pre_send_handler,
result_handler=result_handler, result_handler=result_handler,
error_result_handler=error_result_handler, error_result_handler=error_result_handler,
@@ -658,6 +667,7 @@ class Agent:
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
@@ -680,6 +690,7 @@ class Agent:
counter = PromptsCounter(len(prompts), self.logger) counter = PromptsCounter(len(prompts), self.logger)
system_prompts = itertools.repeat(system_prompt, len(prompts)) system_prompts = itertools.repeat(system_prompt, len(prompts))
json_formats = itertools.repeat(json_format, len(prompts))
counters = itertools.repeat(counter, len(prompts)) counters = itertools.repeat(counter, len(prompts))
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts)) pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts)) result_handlers = itertools.repeat(result_handler, len(prompts))
@@ -699,6 +710,7 @@ class Agent:
clients, clients,
prompts, prompts,
system_prompts, system_prompts,
json_formats,
counters, counters,
pre_send_handlers, pre_send_handlers,
result_handlers, result_handlers,

View File

@@ -57,31 +57,36 @@ Output(target language: {to_lang}):
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are. Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
""" """
def get_original_segments(prompt:str):
def get_original_segments(prompt: str):
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL) match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
if match: if match:
return match.group(1) return match.group(1)
else: else:
raise ValueError("无法从prompt中提取初始文本") raise ValueError("无法从prompt中提取初始文本")
def get_target_segments(result:str):
def get_target_segments(result: str):
match = re.search(r'```json(.*)```', result, re.DOTALL) match = re.search(r'```json(.*)```', result, re.DOTALL)
if match: if match:
return match.group(1) return match.group(1)
else: else:
return result return result
@dataclass @dataclass
class SegmentsTranslateAgentConfig(AgentConfig): class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str to_lang: str
custom_prompt: str | None = None custom_prompt: str | None = None
glossary_dict: dict[str, str] | None = None glossary_dict: dict[str, str] | None = None
json_format:bool = True
class SegmentsTranslateAgent(Agent): class SegmentsTranslateAgent(Agent):
def __init__(self, config: SegmentsTranslateAgentConfig): def __init__(self, config: SegmentsTranslateAgentConfig):
super().__init__(config) super().__init__(config)
self.to_lang = config.to_lang self.to_lang = config.to_lang
self.json_format = config.json_format
self.system_prompt = f""" self.system_prompt = f"""
# Role # Role
- You are a professional, authentic machine translation engine. - You are a professional, authentic machine translation engine.
@@ -104,7 +109,7 @@ class SegmentsTranslateAgent(Agent):
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。 - 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
- 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。 - 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。
""" """
original_segments=get_original_segments(origin_prompt) original_segments = get_original_segments(origin_prompt)
result = get_target_segments(result) result = get_target_segments(result)
if result == "": if result == "":
if original_segments.strip() != "": if original_segments.strip() != "":
@@ -160,7 +165,7 @@ class SegmentsTranslateAgent(Agent):
处理在所有重试后仍然失败的请求。 处理在所有重试后仍然失败的请求。
作为备用方案,返回原文内容,并将所有值转换为字符串。 作为备用方案,返回原文内容,并将所有值转换为字符串。
""" """
original_segments=get_original_segments(origin_prompt) original_segments = get_original_segments(origin_prompt)
if original_segments == "": if original_segments == "":
return {} return {}
try: try:
@@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)
@@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent):
chunk_size) chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
error_result_handler=self._error_result_handler) error_result_handler=self._error_result_handler)

View File

@@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
}, },
), ),
"siliconflow": ("enable_thinking", True, False), "siliconflow": ("enable_thinking", True, False),
"default":("reasoning_effort","medium","minimal"),
} }
@@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
return thinking_mode["volces"] return thinking_mode["volces"]
elif "gemini" in model_id: elif "gemini" in model_id:
return thinking_mode["google"] return thinking_mode["google"]
return None return thinking_mode["default"]
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None: def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
@@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
return thinking_mode["siliconflow"] return thinking_mode["siliconflow"]
elif provider == "api.302.ai": elif provider == "api.302.ai":
return get_thinking_mode_by_model_id(model_id) return get_thinking_mode_by_model_id(model_id)
return None return thinking_mode["default"]
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool): # def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):