优化segments_agent提示词

This commit is contained in:
xunbu
2025-10-16 20:50:18 +08:00
parent 14afb0eb6d
commit 324ad77a2e
2 changed files with 71 additions and 54 deletions

View File

@@ -292,7 +292,7 @@ class Agent:
response.raise_for_status() response.raise_for_status()
# print(f"【测试】resp:\n{response.json()}") # print(f"【测试】resp:\n{response.json()}")
result = response.json()["choices"][0]["message"]["content"] result = response.json()["choices"][0]["message"]["content"]
# print(f"【测试】\nprompt:\n{prompt}\nresp:\n{result}")
# 获取token使用情况 # 获取token使用情况
response_data = response.json() response_data = response.json()
input_tokens, cached_tokens, output_tokens, reasoning_tokens = ( input_tokens, cached_tokens, output_tokens, reasoning_tokens = (

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
import json import json
import re
from dataclasses import dataclass from dataclasses import dataclass
from json import JSONDecodeError from json import JSONDecodeError
from logging import Logger from logging import Logger
@@ -15,6 +16,62 @@ from docutranslate.glossary.glossary import Glossary
from docutranslate.utils.json_utils import segments2json_chunks, fix_json_string from docutranslate.utils.json_utils import segments2json_chunks, fix_json_string
def generate_prompt(json_segments: str, to_lang: str) -> str:
    """Build the translation prompt for one chunk of segments.

    The returned text embeds ``json_segments`` verbatim inside a fenced
    ``json`` block wrapped in ``<input>`` ... ``</input>`` tags, followed by
    instructions to translate every value into ``to_lang`` while keeping the
    segment IDs unchanged, and a short input/output example.

    Args:
        json_segments: Serialized JSON text mapping segment IDs to the text
            to translate; it is inserted into the prompt as-is.
        to_lang: Name of the target language; it is interpolated into the
            instruction sentence and into the example's output heading.

    Returns:
        The complete prompt string.
    """
    # The doubled braces ({{ and }}) below are f-string escapes that render as
    # literal "{" and "}" in the example blocks.
    # NOTE(review): the example inside the prompt is not strict JSON (unquoted
    # keys/values, trailing commas) -- confirm this abbreviated form is intended.
    return f"""
You will receive a sequence of original text segments to be translated, represented in JSON format. The keys are segment IDs, and the values are the text content to be translated.
Here is the input:
<input>
```json
{json_segments}
```
</input>
For each Key-Value Pair in the JSON, translate the contents of the value into {to_lang}, Write the translation back into the value for that JSON.
> (Very important) The original text segments and translated segments must strictly correspond one-to-one. It is strictly forbidden for the IDs of the translated segments to differ from those of the original segments.
> The segment IDs in the output must exactly match those in the input. And all segment IDs in input must appear in the output.
Here is an example of the expected format:
<example>
Input:
```json
{{
3:source,
4:source,
}}
```
Output(target language: {to_lang}):
```json
{{
3:translation,
4:translation,
}}
```
</example>
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
"""
def get_original_segments(prompt:str):
    """Return the text enclosed by the ``<input>`` ... ``</input>`` tags of a prompt.

    The capture is greedy and spans newlines, so it runs from the first
    ``<input>`` to the last ``</input>`` and is returned verbatim (nothing is
    trimmed from it).

    Raises:
        ValueError: If the prompt contains no ``<input>`` ... ``</input>`` pair.
    """
    found = re.search(r'<input>(.*)</input>', prompt, re.DOTALL)
    # Guard clause: without the tag pair there is nothing to recover.
    if found is None:
        raise ValueError("无法从prompt中提取初始文本")
    return found.group(1)
def get_target_segments(result: str):
    """Extract the JSON payload from a model reply.

    If the reply contains a fenced ``json`` code block, the content between
    the first opening fence and the last closing fence is returned with
    surrounding whitespace removed. Trimming matters because the capture
    otherwise keeps the newlines next to the fences, so an empty fenced block
    would come back as ``"\\n"`` rather than ``""`` and would not be
    recognised by an equality check against the empty string.

    If no fenced ``json`` block is found, the reply is returned unchanged.

    Args:
        result: Raw text returned by the model.

    Returns:
        The trimmed fence content, or ``result`` itself when no fence matched.
    """
    # Greedy + DOTALL: span up to the LAST closing fence so that backticks
    # occurring inside translated text do not truncate the payload.
    fenced = re.search(r'```json(.*)```', result, re.DOTALL)
    if fenced is None:
        return result
    return fenced.group(1).strip()
@dataclass @dataclass
class SegmentsTranslateAgentConfig(AgentConfig): class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str to_lang: str
@@ -25,53 +82,10 @@ class SegmentsTranslateAgentConfig(AgentConfig):
class SegmentsTranslateAgent(Agent): class SegmentsTranslateAgent(Agent):
def __init__(self, config: SegmentsTranslateAgentConfig): def __init__(self, config: SegmentsTranslateAgentConfig):
super().__init__(config) super().__init__(config)
self.to_lang = config.to_lang
self.system_prompt = f""" self.system_prompt = f"""
# Role # Role
- You are a text segment translation engine that needs to translate received original text segments into target language text segments. - You are a professional, authentic machine translation engine.
# Task
- You will receive a sequence of original text segments to be translated, represented in JSON format. The keys are segment IDs, and the values are the text content to be translated.
- You need to translate these text segments into the target language.
- Target language: {config.to_lang}
# Requirements
- Translations must be professional and accurate.
- Do not output any explanations or comments but only the {config.to_lang} translations.
- Use the most common translations for personal names and proper nouns.
- Preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
- (Very important) The original text segments and translated segments must strictly correspond one-to-one. It is strictly forbidden for the IDs of the translated segments to differ from those of the original segments.
# Input Specification
{{
"<segment ID>": "<text to be translated>"
}}
# Output Specification
{{
"<segment ID>": "<translated text>"
}}
- The response must be a **valid** JSON object
- Escape the double quotes within the JSON string.
- (very important) The segment IDs in the output must exactly match those in the input. And all segment IDs in input must appear in the output.
# Example (assuming the target language in this example is English, {config.to_lang} is the actual target language)
## Input
{{
"8": "然后呢?我们",
"9": "就可以看到这个界面了",
"10": "乔布斯在上海吃泡面",
"11": "汤姆说:“你好”"
}}
## Correct Output
{{
"8": "And then? We",
"9": "can then see this interface",
"10": "Steve Jobs ate instant noodles in Shanghai.",
"11": "Tom says:\\\"hello\\\""
}}
""" """
self.custom_prompt = config.custom_prompt self.custom_prompt = config.custom_prompt
if config.custom_prompt: if config.custom_prompt:
@@ -91,13 +105,15 @@ class SegmentsTranslateAgent(Agent):
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。 - 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
- 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。 - 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。
""" """
original_segments=get_original_segments(origin_prompt)
result = get_target_segments(result)
if result == "": if result == "":
if origin_prompt.strip() != "": if original_segments.strip() != "":
raise AgentResultError("result为空值但原文不为空") raise AgentResultError("result为空值但原文不为空")
return {} return {}
try: try:
result = fix_json_string(result) result = fix_json_string(result)
original_chunk = json.loads(origin_prompt) original_chunk = json_repair.loads(original_segments)
repaired_result = json_repair.loads(result) repaired_result = json_repair.loads(result)
if not isinstance(repaired_result, dict): if not isinstance(repaired_result, dict):
@@ -144,22 +160,23 @@ class SegmentsTranslateAgent(Agent):
处理在所有重试后仍然失败的请求。 处理在所有重试后仍然失败的请求。
作为备用方案,返回原文内容,并将所有值转换为字符串。 作为备用方案,返回原文内容,并将所有值转换为字符串。
""" """
if origin_prompt == "": original_segments=get_original_segments(origin_prompt)
if original_segments == "":
return {} return {}
try: try:
original_chunk = json.loads(origin_prompt) original_chunk = json_repair.loads(original_segments)
# 此处逻辑保留,作为最终的兜底方案 # 此处逻辑保留,作为最终的兜底方案
for key, value in original_chunk.items(): for key, value in original_chunk.items():
original_chunk[key] = f"{value}" original_chunk[key] = f"{value}"
return original_chunk return original_chunk
except (RuntimeError, JSONDecodeError): except (RuntimeError, JSONDecodeError):
logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}") logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
# 如果原始prompt本身也无效返回一个清晰的错误对象 # 如果原始prompt本身也无效返回一个清晰的错误对象
return {"error": f"{origin_prompt}"} return {"error": f"{original_segments}"}
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]: def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size) indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False, indent=0) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,
@@ -197,7 +214,7 @@ class SegmentsTranslateAgent(Agent):
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]: async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments, indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size) chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False, indent=0) for chunk in chunks] prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler, translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler, result_handler=self._result_handler,