优化segments_agent提示词

This commit is contained in:
xunbu
2025-10-16 20:50:18 +08:00
parent 14afb0eb6d
commit 324ad77a2e
2 changed files with 71 additions and 54 deletions

View File

@@ -292,7 +292,7 @@ class Agent:
response.raise_for_status()
# print(f"【测试】resp:\n{response.json()}")
result = response.json()["choices"][0]["message"]["content"]
# print(f"【测试】\nprompt:\n{prompt}\nresp:\n{result}")
# 获取token使用情况
response_data = response.json()
input_tokens, cached_tokens, output_tokens, reasoning_tokens = (

View File

@@ -3,6 +3,7 @@
import asyncio
import json
import re
from dataclasses import dataclass
from json import JSONDecodeError
from logging import Logger
@@ -15,6 +16,62 @@ from docutranslate.glossary.glossary import Glossary
from docutranslate.utils.json_utils import segments2json_chunks, fix_json_string
def generate_prompt(json_segments: str, to_lang: str):
return f"""
You will receive a sequence of original text segments to be translated, represented in JSON format. The keys are segment IDs, and the values are the text content to be translated.
Here is the input:
<input>
```json
{json_segments}
```
</input>
For each Key-Value Pair in the JSON, translate the contents of the value into {to_lang}, Write the translation back into the value for that JSON.
> (Very important) The original text segments and translated segments must strictly correspond one-to-one. It is strictly forbidden for the IDs of the translated segments to differ from those of the original segments.
> The segment IDs in the output must exactly match those in the input. And all segment IDs in input must appear in the output.
Here is an example of the expected format:
<example>
Input:
```json
{{
3:source,
4:source,
}}
```
Output(target language: {to_lang}):
```json
{{
3:translation,
4:translation,
}}
```
</example>
Please return the translated JSON directly without including any additional information and preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
"""
def get_original_segments(prompt:str):
match = re.search(r'<input>(.*)</input>', prompt, re.DOTALL)
if match:
return match.group(1)
else:
raise ValueError("无法从prompt中提取初始文本")
def get_target_segments(result:str):
match = re.search(r'```json(.*)```', result, re.DOTALL)
if match:
return match.group(1)
else:
return result
@dataclass
class SegmentsTranslateAgentConfig(AgentConfig):
to_lang: str
@@ -25,53 +82,10 @@ class SegmentsTranslateAgentConfig(AgentConfig):
class SegmentsTranslateAgent(Agent):
def __init__(self, config: SegmentsTranslateAgentConfig):
super().__init__(config)
self.to_lang = config.to_lang
self.system_prompt = f"""
# Role
- You are a text segment translation engine that needs to translate received original text segments into target language text segments.
# Task
- You will receive a sequence of original text segments to be translated, represented in JSON format. The keys are segment IDs, and the values are the text content to be translated.
- You need to translate these text segments into the target language.
- Target language: {config.to_lang}
# Requirements
- Translations must be professional and accurate.
- Do not output any explanations or comments but only the {config.to_lang} translations.
- Use the most common translations for personal names and proper nouns.
- Preserve special tags or untranslatable elements (such as code, brand names, technical terms) as they are.
- (Very important) The original text segments and translated segments must strictly correspond one-to-one. It is strictly forbidden for the IDs of the translated segments to differ from those of the original segments.
# Input Specification
{{
"<segment ID>": "<text to be translated>"
}}
# Output Specification
{{
"<segment ID>": "<translated text>"
}}
- The response must be a **valid** JSON object
- Escape the double quotes within the JSON string.
- (very important) The segment IDs in the output must exactly match those in the input. And all segment IDs in input must appear in the output.
# Example (assuming the target language in this example is English, {config.to_lang} is the actual target language)
## Input
{{
"8": "然后呢?我们",
"9": "就可以看到这个界面了",
"10": "乔布斯在上海吃泡面",
"11": "汤姆说:“你好”"
}}
## Correct Output
{{
"8": "And then? We",
"9": "can then see this interface",
"10": "Steve Jobs ate instant noodles in Shanghai.",
"11": "Tom says:\\\"hello\\\""
}}
- You are a professional, authentic machine translation engine.
"""
self.custom_prompt = config.custom_prompt
if config.custom_prompt:
@@ -91,13 +105,15 @@ class SegmentsTranslateAgent(Agent):
- 如果键不匹配,构造一个部分成功的结果,并通过 PartialTranslationError 异常抛出,以触发重试。
- 其他错误如JSON解析失败、模型偷懒则抛出普通 ValueError 触发重试。
"""
original_segments=get_original_segments(origin_prompt)
result = get_target_segments(result)
if result == "":
if origin_prompt.strip() != "":
if original_segments.strip() != "":
raise AgentResultError("result为空值但原文不为空")
return {}
try:
result = fix_json_string(result)
original_chunk = json.loads(origin_prompt)
original_chunk = json_repair.loads(original_segments)
repaired_result = json_repair.loads(result)
if not isinstance(repaired_result, dict):
@@ -144,22 +160,23 @@ class SegmentsTranslateAgent(Agent):
处理在所有重试后仍然失败的请求。
作为备用方案,返回原文内容,并将所有值转换为字符串。
"""
if origin_prompt == "":
original_segments=get_original_segments(origin_prompt)
if original_segments == "":
return {}
try:
original_chunk = json.loads(origin_prompt)
original_chunk = json_repair.loads(original_segments)
# 此处逻辑保留,作为最终的兜底方案
for key, value in original_chunk.items():
original_chunk[key] = f"{value}"
return original_chunk
except (RuntimeError, JSONDecodeError):
logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
logger.error(f"原始prompt也不是有效的json格式: {original_segments}")
# 如果原始prompt本身也无效返回一个清晰的错误对象
return {"error": f"{origin_prompt}"}
return {"error": f"{original_segments}"}
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False, indent=0) for chunk in chunks]
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,
@@ -197,7 +214,7 @@ class SegmentsTranslateAgent(Agent):
async def send_segments_async(self, segments: list[str], chunk_size: int) -> list[str]:
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
chunk_size)
prompts = [json.dumps(chunk, ensure_ascii=False, indent=0) for chunk in chunks]
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
translated_chunks = await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
result_handler=self._result_handler,