提供json_format选项
This commit is contained in:
@@ -234,8 +234,9 @@ class Agent:
|
|||||||
elif self.thinking == "disable":
|
elif self.thinking == "disable":
|
||||||
data[field_thinking] = val_disable
|
data[field_thinking] = val_disable
|
||||||
|
|
||||||
|
|
||||||
def _prepare_request_data(
|
def _prepare_request_data(
|
||||||
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
|
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9,json_format=False
|
||||||
):
|
):
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = self.temperature
|
temperature = self.temperature
|
||||||
@@ -254,6 +255,8 @@ class Agent:
|
|||||||
}
|
}
|
||||||
if self.thinking != "default":
|
if self.thinking != "default":
|
||||||
self._add_thinking_mode(data)
|
self._add_thinking_mode(data)
|
||||||
|
if json_format:
|
||||||
|
data["response_format"] = {"type": "json_object"}
|
||||||
return headers, data
|
return headers, data
|
||||||
|
|
||||||
async def send_async(
|
async def send_async(
|
||||||
@@ -263,6 +266,7 @@ class Agent:
|
|||||||
system_prompt: None | str = None,
|
system_prompt: None | str = None,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
|
json_format=False,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
@@ -274,7 +278,7 @@ class Agent:
|
|||||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
# print(f"system_prompt:\n{system_prompt}")
|
# print(f"system_prompt:\n{system_prompt}")
|
||||||
# print(f"【测试】prompt:\n{prompt}")
|
# print(f"【测试】prompt:\n{prompt}")
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
@@ -412,6 +416,7 @@ class Agent:
|
|||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_concurrent: int | None = None,
|
max_concurrent: int | None = None,
|
||||||
|
json_format=False,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
@@ -454,6 +459,7 @@ class Agent:
|
|||||||
client=client,
|
client=client,
|
||||||
prompt=p_text,
|
prompt=p_text,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
json_format=json_format,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
@@ -494,6 +500,7 @@ class Agent:
|
|||||||
system_prompt: None | str = None,
|
system_prompt: None | str = None,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
|
json_format=False,
|
||||||
pre_send_handler=None,
|
pre_send_handler=None,
|
||||||
result_handler=None,
|
result_handler=None,
|
||||||
error_result_handler=None,
|
error_result_handler=None,
|
||||||
@@ -504,7 +511,7 @@ class Agent:
|
|||||||
if pre_send_handler:
|
if pre_send_handler:
|
||||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
|
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt)
|
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
@@ -638,15 +645,17 @@ class Agent:
|
|||||||
client: httpx.Client,
|
client: httpx.Client,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: None | str,
|
system_prompt: None | str,
|
||||||
|
json_format,
|
||||||
count: PromptsCounter,
|
count: PromptsCounter,
|
||||||
pre_send_handler,
|
pre_send_handler,
|
||||||
result_handler,
|
result_handler,
|
||||||
error_result_handler,
|
error_result_handler
|
||||||
) -> Any:
|
) -> Any:
|
||||||
result = self.send(
|
result = self.send(
|
||||||
client,
|
client,
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
|
json_format=json_format,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
@@ -658,6 +667,7 @@ class Agent:
|
|||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
|
json_format=False,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
@@ -680,6 +690,7 @@ class Agent:
|
|||||||
counter = PromptsCounter(len(prompts), self.logger)
|
counter = PromptsCounter(len(prompts), self.logger)
|
||||||
|
|
||||||
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
system_prompts = itertools.repeat(system_prompt, len(prompts))
|
||||||
|
json_formats = itertools.repeat(json_format, len(prompts))
|
||||||
counters = itertools.repeat(counter, len(prompts))
|
counters = itertools.repeat(counter, len(prompts))
|
||||||
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
|
pre_send_handlers = itertools.repeat(pre_send_handler, len(prompts))
|
||||||
result_handlers = itertools.repeat(result_handler, len(prompts))
|
result_handlers = itertools.repeat(result_handler, len(prompts))
|
||||||
@@ -699,6 +710,7 @@ class Agent:
|
|||||||
clients,
|
clients,
|
||||||
prompts,
|
prompts,
|
||||||
system_prompts,
|
system_prompts,
|
||||||
|
json_formats,
|
||||||
counters,
|
counters,
|
||||||
pre_send_handlers,
|
pre_send_handlers,
|
||||||
result_handlers,
|
result_handlers,
|
||||||
|
|||||||
@@ -57,31 +57,36 @@ Output(target language: {to_lang}):
|
|||||||
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
|
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_original_segments(prompt:str):
|
|
||||||
|
def get_original_segments(prompt: str):
|
||||||
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
|
match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
else:
|
else:
|
||||||
raise ValueError("无法从prompt中提取初始文本")
|
raise ValueError("无法从prompt中提取初始文本")
|
||||||
|
|
||||||
def get_target_segments(result:str):
|
|
||||||
|
def get_target_segments(result: str):
|
||||||
match = re.search(r'```json(.*)```', result, re.DOTALL)
|
match = re.search(r'```json(.*)```', result, re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SegmentsTranslateAgentConfig(AgentConfig):
|
class SegmentsTranslateAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str | None = None
|
custom_prompt: str | None = None
|
||||||
glossary_dict: dict[str, str] | None = None
|
glossary_dict: dict[str, str] | None = None
|
||||||
|
json_format:bool = True
|
||||||
|
|
||||||
|
|
||||||
class SegmentsTranslateAgent(Agent):
|
class SegmentsTranslateAgent(Agent):
|
||||||
def __init__(self, config: SegmentsTranslateAgentConfig):
|
def __init__(self, config: SegmentsTranslateAgentConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.to_lang = config.to_lang
|
self.to_lang = config.to_lang
|
||||||
|
self.json_format = config.json_format
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# Role
|
# Role
|
||||||
- You are a professional, authentic machine translation engine.
|
- You are a professional, authentic machine translation engine.
|
||||||
@@ -104,7 +109,7 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
|
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
|
||||||
- 其他错误(如JSON解析失败、模型偷懒)则抛出普通 ValueError 触发重试。
|
- 其他错误(如JSON解析失败、模型偷懒)则抛出普通 ValueError 触发重试。
|
||||||
"""
|
"""
|
||||||
original_segments=get_original_segments(origin_prompt)
|
original_segments = get_original_segments(origin_prompt)
|
||||||
result = get_target_segments(result)
|
result = get_target_segments(result)
|
||||||
if result == "":
|
if result == "":
|
||||||
if original_segments.strip() != "":
|
if original_segments.strip() != "":
|
||||||
@@ -160,7 +165,7 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
处理在所有重试后仍然失败的请求。
|
处理在所有重试后仍然失败的请求。
|
||||||
作为备用方案,返回原文内容,并将所有值转换为字符串。
|
作为备用方案,返回原文内容,并将所有值转换为字符串。
|
||||||
"""
|
"""
|
||||||
original_segments=get_original_segments(origin_prompt)
|
original_segments = get_original_segments(origin_prompt)
|
||||||
if original_segments == "":
|
if original_segments == "":
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
@@ -178,7 +183,8 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||||
|
|
||||||
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
|
||||||
|
pre_send_handler=self._pre_send_handler,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
|
|
||||||
@@ -216,7 +222,8 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
chunk_size)
|
chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||||
|
|
||||||
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
|
translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
|
||||||
|
pre_send_handler=self._pre_send_handler,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ thinking_mode:dict[mode_type,tuple[thinking_field, enable_value, disable_value]]
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
"siliconflow": ("enable_thinking", True, False),
|
"siliconflow": ("enable_thinking", True, False),
|
||||||
|
"default":("reasoning_effort","medium","minimal"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ def get_thinking_mode_by_model_id(model_id: str) -> tuple[str, str | dict, str |
|
|||||||
return thinking_mode["volces"]
|
return thinking_mode["volces"]
|
||||||
elif "gemini" in model_id:
|
elif "gemini" in model_id:
|
||||||
return thinking_mode["google"]
|
return thinking_mode["google"]
|
||||||
return None
|
return thinking_mode["default"]
|
||||||
|
|
||||||
|
|
||||||
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
|
def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, str | dict] | None:
|
||||||
@@ -62,7 +63,7 @@ def get_thinking_mode(provider: str, model_id: str) -> tuple[str, str | dict, st
|
|||||||
return thinking_mode["siliconflow"]
|
return thinking_mode["siliconflow"]
|
||||||
elif provider == "api.302.ai":
|
elif provider == "api.302.ai":
|
||||||
return get_thinking_mode_by_model_id(model_id)
|
return get_thinking_mode_by_model_id(model_id)
|
||||||
return None
|
return thinking_mode["default"]
|
||||||
|
|
||||||
|
|
||||||
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):
|
# def add_thinking_mode(data: dict, provider: str, model_id: str, think_enable: bool):
|
||||||
|
|||||||
Reference in New Issue
Block a user