增加glossary prompt
This commit is contained in:
@@ -17,6 +17,7 @@ from docutranslate.utils.json_utils import segments2json_chunks
|
|||||||
@dataclass
|
@dataclass
|
||||||
class GlossaryAgentConfig(AgentConfig):
|
class GlossaryAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
|
custom_prompt: str = None
|
||||||
|
|
||||||
|
|
||||||
class GlossaryAgent(Agent):
|
class GlossaryAgent(Agent):
|
||||||
@@ -49,10 +50,12 @@ The output format should be plain JSON text in a list format
|
|||||||
## Output
|
## Output
|
||||||
{r'[{"src": "Jobs", "dst": "乔布斯"}, {"src": "Bill Gates", "dst": "比尔盖茨"}, {"src": "Shanghai", "dst": "上海"}]'}
|
{r'[{"src": "Jobs", "dst": "乔布斯"}, {"src": "Bill Gates", "dst": "比尔盖茨"}, {"src": "Shanghai", "dst": "上海"}]'}
|
||||||
"""
|
"""
|
||||||
|
if config.custom_prompt:
|
||||||
|
self.system_prompt += "\n# **Important rules or background** \n" + self.custom_prompt + '\nEND\n'
|
||||||
|
|
||||||
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
def _result_handler(self, result: str, origin_prompt: str, logger: Logger):
|
||||||
if result == "":
|
if result == "":
|
||||||
if origin_prompt.strip()!="":
|
if origin_prompt.strip() != "":
|
||||||
logger.error("result为空值但原文不为空")
|
logger.error("result为空值但原文不为空")
|
||||||
raise AgentResultError("result为空值但原文不为空")
|
raise AgentResultError("result为空值但原文不为空")
|
||||||
return []
|
return []
|
||||||
@@ -72,7 +75,7 @@ The output format should be plain JSON text in a list format
|
|||||||
return json_repair.loads(origin_prompt)
|
return json_repair.loads(origin_prompt)
|
||||||
except (RuntimeError, JSONDecodeError):
|
except (RuntimeError, JSONDecodeError):
|
||||||
logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
|
logger.error(f"原始prompt也不是有效的json格式: {origin_prompt}")
|
||||||
return [] # 如果原始prompt也无效,返回空列表
|
return [] # 如果原始prompt也无效,返回空列表
|
||||||
|
|
||||||
def send_segments(self, segments: list[str], chunk_size: int):
|
def send_segments(self, segments: list[str], chunk_size: int):
|
||||||
self.logger.info(f"开始提取术语表,to_lang:{self.to_lang}")
|
self.logger.info(f"开始提取术语表,to_lang:{self.to_lang}")
|
||||||
|
|||||||
@@ -300,6 +300,9 @@ class GlossaryAgentConfigPayload(BaseModel):
|
|||||||
system_proxy_enable: bool = Field(
|
system_proxy_enable: bool = Field(
|
||||||
default=default_params["system_proxy_enable"], description="是否使用系统代理", examples=[True, False]
|
default=default_params["system_proxy_enable"], description="是否使用系统代理", examples=[True, False]
|
||||||
)
|
)
|
||||||
|
custom_prompt:Optional[str]=Field(
|
||||||
|
default=None,description="生成术语表的用户自定义提示词"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 1. 定义所有工作流共享的基础参数
|
# 1. 定义所有工作流共享的基础参数
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user