允许自定义系统提示

This commit is contained in:
xunbu
2025-05-28 10:16:42 +08:00
parent 0d29b636da
commit 3c9b94150c
7 changed files with 137 additions and 104 deletions

View File

@@ -76,6 +76,12 @@ class FileTranslater:
}
return result
def default_refine_agent(self, custom_prompt=None) -> MDRefineAgent:
return MDRefineAgent(custom_prompt=custom_prompt, **self._default_agent_params())
def default_translate_agent(self, custom_prompt=None, to_lang="中文") -> MDTranslateAgent:
return MDTranslateAgent(custom_prompt=custom_prompt, to_lang=to_lang, **self._default_agent_params())
def _convert2markdown(self, document: Document, formula: bool, code: bool, artifact: Path = None) -> str:
translater_logger.info(f"正在使用{self.convert_engin}转换文件为markdown")
if self.convert_engin == "docling":
@@ -202,48 +208,49 @@ class FileTranslater:
self.save_as_markdown(filename=f"{file_path.stem}.md")
return self
def refine_markdown_by_agent(self, refine_agent: Agent | None = None) -> str:
def refine_markdown_by_agent(self, refine_agent: Agent | None = None, custom_prompt=None) -> str:
translater_logger.info("正在修正markdown")
self._mask_uris_in_markdown()
chuncks = self._split_markdown_into_chunks()
if refine_agent is None:
refine_agent = MDRefineAgent(**self._default_agent_params())
refine_agent = self.default_refine_agent(custom_prompt)
result: list[str] = refine_agent.send_prompts(chuncks)
self.markdown = join_markdown_texts(result)
self._unmask_uris_in_markdown()
translater_logger.info("markdown已修正")
return self.markdown
def translate_markdown_by_agent(self, translate_agent: Agent | None = None, to_lang="中文"):
def translate_markdown_by_agent(self, translate_agent: Agent | None = None, to_lang="中文", custom_prompt=None):
translater_logger.info("正在翻译markdown")
self._mask_uris_in_markdown()
chuncks = self._split_markdown_into_chunks()
if translate_agent is None:
translate_agent = MDTranslateAgent(to_lang=to_lang, **self._default_agent_params())
translate_agent = self.default_translate_agent(custom_prompt=custom_prompt, to_lang=to_lang)
result: list[str] = translate_agent.send_prompts(chuncks)
self.markdown = join_markdown_texts(result)
self._unmask_uris_in_markdown()
translater_logger.info("翻译完成")
return self.markdown
async def refine_markdown_by_agent_async(self, refine_agent: Agent | None = None) -> str:
async def refine_markdown_by_agent_async(self, refine_agent: Agent | None = None, custom_prompt=None) -> str:
translater_logger.info("正在修正markdown")
self._mask_uris_in_markdown()
chuncks = self._split_markdown_into_chunks()
if refine_agent is None:
refine_agent = MDRefineAgent(**self._default_agent_params())
refine_agent = self.default_refine_agent(custom_prompt=custom_prompt)
result: list[str] = await refine_agent.send_prompts_async(chuncks)
self.markdown = join_markdown_texts(result)
self._unmask_uris_in_markdown()
translater_logger.info("markdown已修正")
return self.markdown
async def translate_markdown_by_agent_async(self, translate_agent: Agent | None = None, to_lang="中文"):
async def translate_markdown_by_agent_async(self, translate_agent: Agent | None = None, to_lang="中文",
custom_prompt=None):
translater_logger.info("正在翻译markdown")
self._mask_uris_in_markdown()
chuncks = self._split_markdown_into_chunks()
if translate_agent is None:
translate_agent = MDTranslateAgent(to_lang=to_lang, **self._default_agent_params())
translate_agent = self.default_translate_agent(to_lang=to_lang, custom_prompt=custom_prompt)
result: list[str] = await translate_agent.send_prompts_async(chuncks)
self.markdown = join_markdown_texts(result)
self._unmask_uris_in_markdown()
@@ -318,7 +325,9 @@ class FileTranslater:
def translate_file(self, file_path: Path | str | None = None, to_lang="中文", output_dir="./output",
formula=True,
code=True, output_format: Literal["markdown", "html"] = "markdown", refine=False,
refine_agent: Agent | None = None, translate_agent: Agent | None = None, save=True):
custom_prompt_translate=None, refine_agent: Agent | None = None,
translate_agent: Agent | None = None,
save=True):
if file_path is None:
assert self.file_path is not None, "未输入文件路径"
file_path = self.file_path
@@ -327,7 +336,7 @@ class FileTranslater:
self.read_file(file_path, formula=formula, code=code)
if refine:
self.refine_markdown_by_agent(refine_agent)
self.translate_markdown_by_agent(translate_agent, to_lang=to_lang)
self.translate_markdown_by_agent(translate_agent, to_lang=to_lang, custom_prompt=custom_prompt_translate)
if save:
if output_format == "markdown":
filename = f"{file_path.stem}_{to_lang}.md"
@@ -339,7 +348,8 @@ class FileTranslater:
async def translate_file_async(self, file_path: Path | str | None = None, to_lang="中文", output_dir="./output",
formula=True,
code=True, output_format: Literal["markdown", "html"] = "markdown", refine=False,
code=True, output_format: Literal["markdown", "html"] = "markdown",
custom_prompt_translate=None, refine=False,
refine_agent: Agent | None = None, translate_agent: Agent | None = None, save=True):
if file_path is None:
assert self.file_path is not None, "未输入文件路径"
@@ -354,7 +364,8 @@ class FileTranslater:
)
if refine:
await self.refine_markdown_by_agent_async(refine_agent)
await self.translate_markdown_by_agent_async(translate_agent, to_lang=to_lang)
await self.translate_markdown_by_agent_async(translate_agent, to_lang=to_lang,
custom_prompt=custom_prompt_translate)
if save:
if output_format == "markdown":
filename = f"{file_path.stem}_{to_lang}.md"
@@ -366,12 +377,14 @@ class FileTranslater:
def translate_bytes(self, name: str, file: bytes, to_lang="中文", output_dir="./output",
formula=True,
code=True, output_format: Literal["markdown", "html"] = "markdown", refine=False,
code=True, output_format: Literal["markdown", "html"] = "markdown",
custom_prompt_translate=None,
refine=False,
refine_agent: Agent | None = None, translate_agent: Agent | None = None, save=True):
self.read_bytes(name=name, file=file, formula=formula, code=code)
if refine:
self.refine_markdown_by_agent(refine_agent)
self.translate_markdown_by_agent(translate_agent, to_lang=to_lang)
self.translate_markdown_by_agent(translate_agent, to_lang=to_lang, custom_prompt=custom_prompt_translate)
if save:
if output_format == "markdown":
filename = f"{name}_{to_lang}.md"
@@ -383,13 +396,15 @@ class FileTranslater:
async def translate_bytes_async(self, name: str, file: bytes, to_lang="中文", output_dir="./output",
formula=True,
code=True, output_format: Literal["markdown", "html"] = "markdown", refine=False,
code=True, output_format: Literal["markdown", "html"] = "markdown",
custom_prompt_translate=None, refine=False,
refine_agent: Agent | None = None, translate_agent: Agent | None = None, save=True):
await self.read_bytes_async(name=name, file=file, formula=formula, code=code)
if refine:
await self.refine_markdown_by_agent_async(refine_agent)
await self.translate_markdown_by_agent_async(translate_agent, to_lang=to_lang)
await self.translate_markdown_by_agent_async(translate_agent, to_lang=to_lang,
custom_prompt=custom_prompt_translate)
if save:
if output_format == "markdown":
filename = f"{name}_{to_lang}.md"