修复markdown出错时不返回原文的问题,优化提示词

This commit is contained in:
xunbu
2025-10-18 21:25:32 +08:00
parent 0700dbf58a
commit b31017bfb7
5 changed files with 33 additions and 16 deletions

View File

@@ -51,7 +51,7 @@ Output:
def get_original_segments(prompt: str): def get_original_segments(prompt: str):
match = re.search(r'<input>(.*)</input>', prompt, re.DOTALL) match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
if match: if match:
return match.group(1) return match.group(1)
else: else:

View File

@@ -1,11 +1,20 @@
# SPDX-FileCopyrightText: 2025 QinHan # SPDX-FileCopyrightText: 2025 QinHan
# SPDX-License-Identifier: MPL-2.0 # SPDX-License-Identifier: MPL-2.0
import re
from dataclasses import dataclass from dataclasses import dataclass
from .agent import Agent, AgentConfig from .agent import Agent, AgentConfig
from ..glossary.glossary import Glossary from ..glossary.glossary import Glossary
def get_original_markdown(prompt: str) -> str:
    """Recover the markdown text embedded in a generated prompt.

    The text is expected to sit between an ``<input>`` line and an
    ``</input>`` line. The match is greedy and spans newlines, so the
    returned text runs from the first ``<input>`` marker up to the last
    ``</input>`` marker found in *prompt*.

    Args:
        prompt: The full prompt string containing the wrapped markdown.

    Returns:
        The text between the markers, without the newline that follows
        ``<input>`` or the one that precedes ``</input>``.

    Raises:
        ValueError: If *prompt* contains no ``<input>``/``</input>`` pair
            in the expected layout.
    """
    found = re.search(r'<input>\n(.*)\n</input>', prompt, re.DOTALL)
    # Guard clause: without the markers there is nothing to recover.
    if not found:
        raise ValueError("无法从prompt中提取初始文本")
    return found.group(1)
def generate_prompt(markdown_text: str, to_lang: str): def generate_prompt(markdown_text: str, to_lang: str):
return f""" return f"""
Treat the text input as markdown text and translate it into {to_lang},output translation ONLY. Treat the text input as markdown text and translate it into {to_lang},output translation ONLY.
@@ -20,9 +29,12 @@ Treat the text input as markdown text and translate it into {to_lang},output tra
- Output the translated markdown text as plain text (not in a markdown code block, with no extraneous text). - Output the translated markdown text as plain text (not in a markdown code block, with no extraneous text).
The markdown text input: The markdown text input:
<input>
{markdown_text} {markdown_text}
</input>
""" """
@dataclass @dataclass
class MDTranslateAgentConfig(AgentConfig): class MDTranslateAgentConfig(AgentConfig):
to_lang: str to_lang: str
@@ -33,7 +45,7 @@ class MDTranslateAgentConfig(AgentConfig):
class MDTranslateAgent(Agent): class MDTranslateAgent(Agent):
def __init__(self, config: MDTranslateAgentConfig): def __init__(self, config: MDTranslateAgentConfig):
super().__init__(config) super().__init__(config)
self.to_lang=config.to_lang self.to_lang = config.to_lang
self.system_prompt = f""" self.system_prompt = f"""
# Role # Role
You are a professional machine translation engine. You are a professional machine translation engine.
@@ -50,12 +62,15 @@ You are a professional machine translation engine.
return system_prompt, prompt return system_prompt, prompt
def send_chunks(self, prompts: list[str]): def send_chunks(self, prompts: list[str]):
prompts=[generate_prompt(prompt,self.to_lang) for prompt in prompts] prompts = [generate_prompt(prompt, self.to_lang) for prompt in prompts]
return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler) return super().send_prompts(prompts=prompts, pre_send_handler=self._pre_send_handler,
error_result_handler=lambda prompt, logger: get_original_markdown(prompt))
async def send_chunks_async(self, prompts: list[str]): async def send_chunks_async(self, prompts: list[str]):
prompts = [generate_prompt(prompt, self.to_lang) for prompt in prompts] prompts = [generate_prompt(prompt, self.to_lang) for prompt in prompts]
return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler) return await super().send_prompts_async(prompts=prompts, pre_send_handler=self._pre_send_handler,
error_result_handler=lambda prompt, logger: get_original_markdown(
prompt))
def update_glossary_dict(self, update_dict: dict | None): def update_glossary_dict(self, update_dict: dict | None):
if self.glossary_dict is None: if self.glossary_dict is None:

View File

@@ -22,11 +22,9 @@ You will receive a sequence of original text segments to be translated, represen
Here is the input: Here is the input:
<input> <input>
```json ```json
{json_segments} {json_segments}
``` ```
</input> </input>
For each Key-Value Pair in the JSON, translate the contents of the value into {to_lang}, Write the translation back into the value for that JSON. For each Key-Value Pair in the JSON, translate the contents of the value into {to_lang}, Write the translation back into the value for that JSON.
@@ -60,7 +58,7 @@ Please return the translated JSON directly without including any additional info
""" """
def get_original_segments(prompt:str): def get_original_segments(prompt:str):
match = re.search(r'<input>(.*)</input>', prompt, re.DOTALL) match = re.search(r'<input>\n```json\n(.*)\n```\n</input>', prompt, re.DOTALL)
if match: if match:
return match.group(1) return match.group(1)
else: else:

View File

@@ -16,25 +16,29 @@ class MDBasedCovertCacher:
self.cache_dict = OrderedDict() self.cache_dict = OrderedDict()
@staticmethod @staticmethod
def _get_hashcode(document: Document, convert_engin: str, convert_config: ConverterConfig|None) -> str: def _get_hashcode(document: Document, convert_engin: str, convert_config: ConverterConfig | None) -> str:
if convert_config : if convert_config:
convert_config_hash=convert_config.gethash() convert_config_hash = convert_config.gethash()
else: else:
convert_config_hash=None convert_config_hash = None
obj = (document.suffix, document.content, convert_engin, convert_config_hash) obj = (document.suffix, document.content, convert_engin, convert_config_hash)
return str(hash(obj)) return str(hash(obj))
def get_cached_result(self, document: Document, convert_engin: str, def get_cached_result(self, document: Document, convert_engin: str,
convert_config: ConverterConfig) -> MarkdownDocument | None: convert_config: ConverterConfig) -> MarkdownDocument | None:
return self.cache_dict.get(self._get_hashcode(document, convert_engin, convert_config)) d: MarkdownDocument | None = self.cache_dict.get(self._get_hashcode(document, convert_engin, convert_config))
if d:
return d.copy()
else:
return None
def cache_result(self, convert_result: MarkdownDocument, document: Document, convert_engin: str, def cache_result(self, convert_result: MarkdownDocument, document: Document, convert_engin: str,
convert_config: ConverterConfig) -> MarkdownDocument: convert_config: ConverterConfig) -> MarkdownDocument:
hash_code = self._get_hashcode(document, convert_engin, convert_config) hash_code = self._get_hashcode(document, convert_engin, convert_config)
if len(self.cache_dict) > int(CACHE_NUM): if len(self.cache_dict) > int(CACHE_NUM):
self.cache_dict.popitem(last=False) self.cache_dict.popitem(last=False)
self.cache_dict[hash_code] = convert_result self.cache_dict[hash_code] = convert_result.copy()
return convert_result return convert_result
def clear(self): def clear(self):

View File

@@ -83,7 +83,7 @@ class MarkdownBasedWorkflow(Workflow[MarkdownBasedWorkflowConfig, Document, Mark
for attachment in converter.attachments: for attachment in converter.attachments:
self.attachment.add_attachment(attachment) self.attachment.add_attachment(attachment)
# 缓存解析后文件 # 缓存解析后文件
md_based_convert_cacher.cache_result(document_md.copy(), self.document_original, convert_engin, convert_config) md_based_convert_cacher.cache_result(document_md, self.document_original, convert_engin, convert_config)
return document_md return document_md