提供json_format选项
This commit is contained in:
@@ -234,8 +234,9 @@ class Agent:
|
||||
elif self.thinking == "disable":
|
||||
data[field_thinking] = val_disable
|
||||
|
||||
|
||||
def _prepare_request_data(
|
||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False
|
||||
):
|
||||
if temperature is None:
|
||||
temperature = self.temperature
|
||||
@@ -254,6 +255,8 @@ class Agent:
|
||||
}
|
||||
if self.thinking != "default":
|
||||
self._add_thinking_mode(data)
|
||||
if json_format:
|
||||
data["response_format"] = {"type": "json_object"}
|
||||
return headers, data
|
||||
|
||||
async def send_async(
|
||||
@@ -263,6 +266,7 @@ class Agent:
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
json_format=False,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
@@ -274,7 +278,7 @@ class Agent:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
# print(f"system_prompt:\n{system_prompt}")
|
||||
# print(f"【测试】prompt:\n{prompt}")
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
||||
should_retry = False
|
||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||
current_partial_result = None
|
||||
@@ -412,6 +416,7 @@ class Agent:
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
max_concurrent: int | None = None,
|
||||
json_format=False,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
@@ -454,6 +459,7 @@ class Agent:
|
||||
client=client,
|
||||
prompt=p_text,
|
||||
system_prompt=system_prompt,
|
||||
json_format=json_format,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
@@ -494,6 +500,7 @@ class Agent:
|
||||
system_prompt: None | str = None,
|
||||
retry=True,
|
||||
retry_count=0,
|
||||
json_format=False,
|
||||
pre_send_handler=None,
|
||||
result_handler=None,
|
||||
error_result_handler=None,
|
||||
@@ -504,7 +511,7 @@ class Agent:
|
||||
if pre_send_handler:
|
||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
||||
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
||||
should_retry = False
|
||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||
current_partial_result = None
|
||||
@@ -638,15 +645,17 @@ class Agent:
|
||||
client: httpx.Client,
|
||||
prompt: str,
|
||||
system_prompt: None | str,
|
||||
json_format,
|
||||
count: PromptsCounter,
|
||||
pre_send_handler,
|
||||
result_handler,
|
||||
error_result_handler,
|
||||
error_result_handler
|
||||
) -> Any:
|
||||
result = self.send(
|
||||
client,
|
||||
prompt,
|
||||
system_prompt,
|
||||
json_format=json_format,
|
||||
pre_send_handler=pre_send_handler,
|
||||
result_handler=result_handler,
|
||||
error_result_handler=error_result_handler,
|
||||
@@ -658,6 +667,7 @@ class Agent:
|
||||
self,
|
||||
prompts: list[str],
|
||||
system_prompt: str | None = None,
|
||||
json_format=False,
|
||||
pre_send_handler: PreSendHandlerType = None,
|
||||
result_handler: ResultHandlerType = None,
|
||||
error_result_handler: ErrorResultHandlerType = None,
|
||||
@@ -680,6 +690,7 @@ class Agent:
|
||||
counter = PromptsCounter(len(prompts), self.logger)
|
||||
|
||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||
json_formats = itertools.repeat(json_format, len(prompts))
|
||||
counters = itertools.repeat(counter, len(prompts))
|
||||
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
|
||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||
@@ -699,6 +710,7 @@ class Agent:
|
||||
clients,
|
||||
prompts,
|
||||
system_prompts,
|
||||
json_formats,
|
||||
counters,
|
||||
pre_send_handlers,
|
||||
result_handlers,
|
||||
|
||||
@@ -57,6 +57,7 @@ Output(target language: {to_lang}):
|
||||
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
|
||||
"""
|
||||
|
||||
|
||||
def get_original_segments(prompt: str):
|
||||
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
|
||||
if match:
|
||||
@@ -64,6 +65,7 @@ def get_original_segments(prompt:str):
|
||||
else:
|
||||
raise ValueError("无法从prompt中提取初始文本")
|
||||
|
||||
|
||||
def get_target_segments(result: str):
|
||||
match = re.search(r'```json(.*)```', result, re.DOTALL)
|
||||
if match:
|
||||
@@ -71,17 +73,20 @@ def get_target_segments(result:str):
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentsTranslateAgentConfig(AgentConfig):
|
||||
to_lang: str
|
||||
custom_prompt: str | None = None
|
||||
glossary_dict: dict[str, str] | None = None
|
||||
json_format:bool = True
|
||||
|
||||
|
||||
class SegmentsTranslateAgent(Agent):
|
||||
def __init__(self, config: SegmentsTranslateAgentConfig):
|
||||
super().__init__(config)
|
||||
self.to_lang = config.to_lang
|
||||
self.json_format = config.json_format
|
||||
self.system_prompt = f"""
|
||||
# Role
|
||||
- You are a professional, authentic machine translation engine.
|
||||
@@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent):
|
||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||
|
||||
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
|
||||
pre_send_handler=self._pre_send_handler,
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
|
||||
@@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent):
|
||||
chunk_size)
|
||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
||||
translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
|
||||
pre_send_handler=self._pre_send_handler,
|
||||
result_handler=self._result_handler,
|
||||
error_result_handler=self._error_result_handler)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
|
||||
},
|
||||
),
|
||||
"siliconflow": ("enable_thinking", True, False),
|
||||
"default":("reasoning_effort","medium","minimal"),
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
|
||||
return thinking_mode["volces"]
|
||||
elif "gemini" in model_id:
|
||||
return thinking_mode["google"]
|
||||
return None
|
||||
return thinking_mode["default"]
|
||||
|
||||
|
||||
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
|
||||
@@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
|
||||
return thinking_mode["siliconflow"]
|
||||
elif provider == "api.302.ai":
|
||||
return get_thinking_mode_by_model_id(model_id)
|
||||
return None
|
||||
return thinking_mode["default"]
|
||||
|
||||
|
||||
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):
|
||||
|
||||
Reference in New Issue
Block a user