提供json_format选项

This commit is contained in:
xunbu
2025-11-09 14:31:57 +08:00
parent 66015c23b2
commit 69028b2e7f
3 changed files with 32 additions and 12 deletions

View File

@@ -234,8 +234,9 @@ class Agent:
elif self.thinking == "disable":
data[field_thinking] = val_disable
def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False
):
if temperature is None:
temperature = self.temperature
@@ -254,6 +255,8 @@ class Agent:
}
if self.thinking != "default":
self._add_thinking_mode(data)
if json_format:
data["response_format"] = {"type": "json_object"}
return headers, data
async def send_async(
@@ -263,6 +266,7 @@ class Agent:
system_prompt: None | str = None,
retry=True,
retry_count=0,
json_format=False,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
@@ -274,7 +278,7 @@ class Agent:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
# print(f"system_prompt:\n{system_prompt}")
# print(f"【测试】prompt:\n{prompt}")
headers, data = self._prepare_request_data(prompt, system_prompt)
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None
@@ -412,6 +416,7 @@ class Agent:
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
@@ -454,6 +459,7 @@ class Agent:
client=client,
prompt=p_text,
system_prompt=system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
@@ -494,6 +500,7 @@ class Agent:
system_prompt: None | str = None,
retry=True,
retry_count=0,
json_format=False,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
@@ -504,7 +511,7 @@ class Agent:
if pre_send_handler:
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
headers, data = self._prepare_request_data(prompt, system_prompt)
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
should_retry = False
is_hard_error = False # 新增标志,用于区分是否为硬错误
current_partial_result = None
@@ -638,15 +645,17 @@ class Agent:
client: httpx.Client,
prompt: str,
system_prompt: None | str,
json_format,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler,
error_result_handler
) -> Any:
result = self.send(
client,
prompt,
system_prompt,
json_format=json_format,
pre_send_handler=pre_send_handler,
result_handler=result_handler,
error_result_handler=error_result_handler,
@@ -658,6 +667,7 @@ class Agent:
self,
prompts: list[str],
system_prompt: str | None = None,
json_format=False,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
@@ -680,6 +690,7 @@ class Agent:
counter = PromptsCounter(len(prompts), self.logger)
system_prompts = itertools.repeat(system_prompt, len(prompts))
json_formats = itertools.repeat(json_format, len(prompts))
counters = itertools.repeat(counter, len(prompts))
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
result_handlers = itertools.repeat(result_handler, len(prompts))
@@ -699,6 +710,7 @@ class Agent:
clients,
prompts,
system_prompts,
json_formats,
counters,
pre_send_handlers,
result_handlers,

View File

@@ -57,6 +57,7 @@ Output(target language: {to_lang}):
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
"""
def get_original_segments(prompt: str):
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
if match:
@@ -64,6 +65,7 @@ def get_original_segments(prompt:str):
else:
raise ValueError("无法从prompt中提取初始文本")
def get_target_segments(result: str):
match = re.search(r'```json(.*)```', result, re.DOTALL)
if match:
@@ -71,17 +73,20 @@ def get_target_segments(result:str):
else:
return result
@dataclass
class SegmentsTranslateAgentConfig(AgentConfig):
    """Configuration for ``SegmentsTranslateAgent``, extending ``AgentConfig``."""

    # Target language; the agent stores it as ``self.to_lang`` and passes it
    # to ``generate_prompt`` when building each chunk prompt.
    to_lang: str

    # Optional extra prompt text. NOTE(review): how it is consumed is not
    # visible in this view — confirm against the agent implementation.
    custom_prompt: str | None = None

    # Optional glossary mapping. Presumably source term -> preferred
    # translation; verify against the code that reads it.
    glossary_dict: dict[str, str] | None = None

    # Forwarded by the agent as ``json_format`` to ``send_prompts`` /
    # ``send_prompts_async``. When truthy, the request payload gets
    # ``response_format = {"type": "json_object"}``.
    json_format: bool = True
class SegmentsTranslateAgent(Agent):
def __init__(self, config: SegmentsTranslateAgentConfig):
super().__init__(config)
self.to_lang = config.to_lang
self.json_format = config.json_format
self.system_prompt = f"""
# Role
- You are a professional, authentic machine translation engine.
@@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent):
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)
@@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent):
chunk_size)
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
error_result_handler=self._error_result_handler)

View File

@@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
},
),
"siliconflow": ("enable_thinking", True, False),
"default":("reasoning_effort","medium","minimal"),
}
@@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
return thinking_mode["volces"]
elif "gemini" in model_id:
return thinking_mode["google"]
return None
return thinking_mode["default"]
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
@@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
return thinking_mode["siliconflow"]
elif provider == "api.302.ai":
return get_thinking_mode_by_model_id(model_id)
return None
return thinking_mode["default"]
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):