统一接口名称
This commit is contained in:
@@ -50,6 +50,7 @@ class AgentConfig:
|
|||||||
thinking: ThinkingMode = "disable"
|
thinking: ThinkingMode = "disable"
|
||||||
retry: int = 2
|
retry: int = 2
|
||||||
system_proxy_enable: bool = False
|
system_proxy_enable: bool = False
|
||||||
|
force_json: bool = False # 应输出json格式时强制ai输出json
|
||||||
|
|
||||||
|
|
||||||
class TotalErrorCounter:
|
class TotalErrorCounter:
|
||||||
@@ -267,7 +268,7 @@ class Agent:
|
|||||||
system_prompt: None | str = None,
|
system_prompt: None | str = None,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
json_format=False,
|
force_json=False,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
@@ -279,7 +280,7 @@ class Agent:
|
|||||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
# print(f"system_prompt:\n{system_prompt}")
|
# print(f"system_prompt:\n{system_prompt}")
|
||||||
# print(f"【测试】prompt:\n{prompt}")
|
# print(f"【测试】prompt:\n{prompt}")
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json)
|
||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
@@ -392,6 +393,7 @@ class Agent:
|
|||||||
system_prompt,
|
system_prompt,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
|
force_json=force_json,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
@@ -419,7 +421,7 @@ class Agent:
|
|||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_concurrent: int | None = None,
|
max_concurrent: int | None = None,
|
||||||
json_format=False,
|
force_json=False,
|
||||||
pre_send_handler: PreSendHandlerType = None,
|
pre_send_handler: PreSendHandlerType = None,
|
||||||
result_handler: ResultHandlerType = None,
|
result_handler: ResultHandlerType = None,
|
||||||
error_result_handler: ErrorResultHandlerType = None,
|
error_result_handler: ErrorResultHandlerType = None,
|
||||||
@@ -429,7 +431,7 @@ class Agent:
|
|||||||
)
|
)
|
||||||
total = len(prompts)
|
total = len(prompts)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{json_format}"
|
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{max_concurrent},temperature:{self.temperature},system_proxy:{self.system_proxy_enable},json_output:{force_json}"
|
||||||
)
|
)
|
||||||
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
|
||||||
self.total_error_counter.max_errors_count = (
|
self.total_error_counter.max_errors_count = (
|
||||||
@@ -462,7 +464,7 @@ class Agent:
|
|||||||
client=client,
|
client=client,
|
||||||
prompt=p_text,
|
prompt=p_text,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
json_format=json_format,
|
force_json=force_json,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
@@ -503,7 +505,7 @@ class Agent:
|
|||||||
system_prompt: None | str = None,
|
system_prompt: None | str = None,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=0,
|
retry_count=0,
|
||||||
json_format=False,
|
force_json=False,
|
||||||
pre_send_handler=None,
|
pre_send_handler=None,
|
||||||
result_handler=None,
|
result_handler=None,
|
||||||
error_result_handler=None,
|
error_result_handler=None,
|
||||||
@@ -514,7 +516,7 @@ class Agent:
|
|||||||
if pre_send_handler:
|
if pre_send_handler:
|
||||||
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
system_prompt, prompt = pre_send_handler(system_prompt, prompt)
|
||||||
|
|
||||||
headers, data = self._prepare_request_data(prompt, system_prompt,json_format=json_format)
|
headers, data = self._prepare_request_data(prompt, system_prompt, json_format=force_json)
|
||||||
should_retry = False
|
should_retry = False
|
||||||
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
is_hard_error = False # 新增标志,用于区分是否为硬错误
|
||||||
current_partial_result = None
|
current_partial_result = None
|
||||||
@@ -621,6 +623,7 @@ class Agent:
|
|||||||
system_prompt,
|
system_prompt,
|
||||||
retry=True,
|
retry=True,
|
||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
|
force_json=force_json,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
@@ -648,7 +651,7 @@ class Agent:
|
|||||||
client: httpx.Client,
|
client: httpx.Client,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: None | str,
|
system_prompt: None | str,
|
||||||
json_format,
|
force_json,
|
||||||
count: PromptsCounter,
|
count: PromptsCounter,
|
||||||
pre_send_handler,
|
pre_send_handler,
|
||||||
result_handler,
|
result_handler,
|
||||||
@@ -658,7 +661,7 @@ class Agent:
|
|||||||
client,
|
client,
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
json_format=json_format,
|
force_json=force_json,
|
||||||
pre_send_handler=pre_send_handler,
|
pre_send_handler=pre_send_handler,
|
||||||
result_handler=result_handler,
|
result_handler=result_handler,
|
||||||
error_result_handler=error_result_handler,
|
error_result_handler=error_result_handler,
|
||||||
|
|||||||
@@ -66,16 +66,18 @@ def get_target_segments(result: str):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(kw_only=True)
|
||||||
class GlossaryAgentConfig(AgentConfig):
|
class GlossaryAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str = None
|
custom_prompt: str = None
|
||||||
|
force_json: bool = False
|
||||||
|
|
||||||
|
|
||||||
class GlossaryAgent(Agent):
|
class GlossaryAgent(Agent):
|
||||||
def __init__(self, config: GlossaryAgentConfig):
|
def __init__(self, config: GlossaryAgentConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.to_lang = config.to_lang
|
self.to_lang = config.to_lang
|
||||||
|
self.force_json=config.force_json
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# Role
|
# Role
|
||||||
You are a professional glossary extractor
|
You are a professional glossary extractor
|
||||||
@@ -114,7 +116,7 @@ You are a professional glossary extractor
|
|||||||
result = {}
|
result = {}
|
||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False), self.to_lang) for chunk in chunks]
|
||||||
translated_chunks = super().send_prompts(prompts=prompts,
|
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.force_json,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
for chunk in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
@@ -138,7 +140,7 @@ You are a professional glossary extractor
|
|||||||
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
indexed_originals, chunks, merged_indices_list = await asyncio.to_thread(segments2json_chunks, segments,
|
||||||
chunk_size)
|
chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False), self.to_lang) for chunk in chunks]
|
||||||
translated_chunks = await super().send_prompts_async(prompts=prompts,
|
translated_chunks = await super().send_prompts_async(prompts=prompts, force_json=self.force_json,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
for chunk in translated_chunks:
|
for chunk in translated_chunks:
|
||||||
|
|||||||
@@ -90,19 +90,19 @@ def get_target_segments(result: str):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(kw_only=True)
|
||||||
class SegmentsTranslateAgentConfig(AgentConfig):
|
class SegmentsTranslateAgentConfig(AgentConfig):
|
||||||
to_lang: str
|
to_lang: str
|
||||||
custom_prompt: str | None = None
|
custom_prompt: str | None = None
|
||||||
glossary_dict: dict[str, str] | None = None
|
glossary_dict: dict[str, str] | None = None
|
||||||
json_format:bool = False
|
force_json:bool = False
|
||||||
|
|
||||||
|
|
||||||
class SegmentsTranslateAgent(Agent):
|
class SegmentsTranslateAgent(Agent):
|
||||||
def __init__(self, config: SegmentsTranslateAgentConfig):
|
def __init__(self, config: SegmentsTranslateAgentConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.to_lang = config.to_lang
|
self.to_lang = config.to_lang
|
||||||
self.json_format = config.json_format
|
self.force_json = config.force_json
|
||||||
self.system_prompt = f"""
|
self.system_prompt = f"""
|
||||||
# Role
|
# Role
|
||||||
- You are a professional, authentic machine translation engine.
|
- You are a professional, authentic machine translation engine.
|
||||||
@@ -199,7 +199,7 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
def send_segments(self, segments: list[str], chunk_size: int) -> list[str]:
|
||||||
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
indexed_originals, chunks, merged_indices_list = segments2json_chunks(segments, chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||||
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.json_format,
|
translated_chunks = super().send_prompts(prompts=prompts, json_format=self.force_json,
|
||||||
pre_send_handler=self._pre_send_handler,
|
pre_send_handler=self._pre_send_handler,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
@@ -238,7 +238,7 @@ class SegmentsTranslateAgent(Agent):
|
|||||||
chunk_size)
|
chunk_size)
|
||||||
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
prompts = [generate_prompt(json.dumps(chunk, ensure_ascii=False, indent=0), self.to_lang) for chunk in chunks]
|
||||||
|
|
||||||
translated_chunks = await super().send_prompts_async(prompts=prompts, json_format=self.json_format,
|
translated_chunks = await super().send_prompts_async(prompts=prompts, force_json=self.force_json,
|
||||||
pre_send_handler=self._pre_send_handler,
|
pre_send_handler=self._pre_send_handler,
|
||||||
result_handler=self._result_handler,
|
result_handler=self._result_handler,
|
||||||
error_result_handler=self._error_result_handler)
|
error_result_handler=self._error_result_handler)
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ async def lifespan(app: FastAPI):
|
|||||||
global_logger.propagate = False
|
global_logger.propagate = False
|
||||||
global_logger.setLevel(logging.INFO)
|
global_logger.setLevel(logging.INFO)
|
||||||
print("应用启动完成,多任务状态已初始化。")
|
print("应用启动完成,多任务状态已初始化。")
|
||||||
print(f"服务接口文档: http://1227.0.0.1:{app.state.port_to_use}/docs")
|
print(f"服务接口文档: http://127.0.0.1:{app.state.port_to_use}/docs")
|
||||||
print(f"请用浏览器访问 http://127.0.0.1:{app.state.port_to_use}\n")
|
print(f"请用浏览器访问 http://127.0.0.1:{app.state.port_to_use}\n")
|
||||||
yield
|
yield
|
||||||
# 清理任何可能残留的临时目录
|
# 清理任何可能残留的临时目录
|
||||||
@@ -311,6 +311,9 @@ class GlossaryAgentConfigPayload(BaseModel):
|
|||||||
custom_prompt: Optional[str] = Field(
|
custom_prompt: Optional[str] = Field(
|
||||||
default=None, description="生成术语表的用户自定义提示词"
|
default=None, description="生成术语表的用户自定义提示词"
|
||||||
)
|
)
|
||||||
|
force_json: bool = Field(
|
||||||
|
default=False, description="强制Agent输出JSON格式的术语表。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 1. 定义所有工作流共享的基础参数
|
# 1. 定义所有工作流共享的基础参数
|
||||||
@@ -736,6 +739,7 @@ class TranslateServiceRequest(BaseModel):
|
|||||||
"timeout": default_params["timeout"],
|
"timeout": default_params["timeout"],
|
||||||
"thinking": "default",
|
"thinking": "default",
|
||||||
"retry": default_params["retry"],
|
"retry": default_params["retry"],
|
||||||
|
"force_json": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -40,7 +40,7 @@ class AssTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ class AiTranslatorConfig(TranslatorConfig, AgentConfig):
|
|||||||
glossary_generate_enable: bool = False
|
glossary_generate_enable: bool = False
|
||||||
glossary_agent_config: GlossaryAgentConfig | None = None
|
glossary_agent_config: GlossaryAgentConfig | None = None
|
||||||
skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项
|
skip_translate: bool = False # 当skip_translate为False时base_url、model_id为必填项
|
||||||
force_json:bool=False # 应输出json格式时强制ai输出json
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Document)
|
T = TypeVar("T", bound=Document)
|
||||||
|
|
||||||
@@ -64,6 +63,7 @@ class AiTranslator(Translator[T]):
|
|||||||
logger=self.logger,
|
logger=self.logger,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
|
force_json=config.force_json,
|
||||||
)
|
)
|
||||||
self.glossary_agent = GlossaryAgent(glossary_agent_config)
|
self.glossary_agent = GlossaryAgent(glossary_agent_config)
|
||||||
|
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class DocxTranslator(AiTranslator):
|
|||||||
api_key=config.api_key, model_id=config.model_id, temperature=config.temperature,
|
api_key=config.api_key, model_id=config.model_id, temperature=config.temperature,
|
||||||
thinking=config.thinking, concurrent=config.concurrent, timeout=config.timeout,
|
thinking=config.thinking, concurrent=config.concurrent, timeout=config.timeout,
|
||||||
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
|
logger=self.logger, glossary_dict=config.glossary_dict, retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,json_format=config.force_json
|
system_proxy_enable=config.system_proxy_enable, force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class EpubTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class HtmlTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class JsonTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.json_paths = config.json_paths
|
self.json_paths = config.json_paths
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class SrtTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class TXTTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class XlsxTranslator(AiTranslator):
|
|||||||
glossary_dict=config.glossary_dict,
|
glossary_dict=config.glossary_dict,
|
||||||
retry=config.retry,
|
retry=config.retry,
|
||||||
system_proxy_enable=config.system_proxy_enable,
|
system_proxy_enable=config.system_proxy_enable,
|
||||||
json_format=config.force_json
|
force_json=config.force_json
|
||||||
)
|
)
|
||||||
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
self.translate_agent = SegmentsTranslateAgent(agent_config)
|
||||||
self.insert_mode = config.insert_mode
|
self.insert_mode = config.insert_mode
|
||||||
|
|||||||
Reference in New Issue
Block a user