diff --git a/docutranslate/agents/agent.py b/docutranslate/agents/agent.py index d8389da..61dbb98 100644 --- a/docutranslate/agents/agent.py +++ b/docutranslate/agents/agent.py @@ -40,7 +40,7 @@ class PartialAgentResultError(ValueError): @dataclass(kw_only=True) class AgentConfig: - logger: logging.Logger + logger: logging.Logger = global_logger base_url: str api_key: str | None = None model_id: str @@ -111,14 +111,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: # 尝试从不同格式获取cached_tokens # 格式1: input_tokens_details.cached_tokens if ( - "input_tokens_details" in usage - and "cached_tokens" in usage["input_tokens_details"] + "input_tokens_details" in usage + and "cached_tokens" in usage["input_tokens_details"] ): cached_tokens = usage["input_tokens_details"]["cached_tokens"] # 格式2: prompt_tokens_details.cached_tokens elif ( - "prompt_tokens_details" in usage - and "cached_tokens" in usage["prompt_tokens_details"] + "prompt_tokens_details" in usage + and "cached_tokens" in usage["prompt_tokens_details"] ): cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] # 格式3: prompt_cache_hit_tokens (直接在usage下) @@ -128,14 +128,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]: # 尝试从不同格式获取reasoning_tokens # 格式1: output_tokens_details.reasoning_tokens if ( - "output_tokens_details" in usage - and "reasoning_tokens" in usage["output_tokens_details"] + "output_tokens_details" in usage + and "reasoning_tokens" in usage["output_tokens_details"] ): reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] # 格式2: completion_tokens_details.reasoning_tokens elif ( - "completion_tokens_details" in usage - and "reasoning_tokens" in usage["completion_tokens_details"] + "completion_tokens_details" in usage + and "reasoning_tokens" in usage["completion_tokens_details"] ): reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] @@ -153,11 +153,11 @@ class TokenCounter: self.logger = logger def add( - self, - input_tokens: int, - cached_tokens: int, - output_tokens: int, - reasoning_tokens: int, + self, + input_tokens: int, + cached_tokens: int, + output_tokens: int, + reasoning_tokens: int, ): with self.lock: self.input_tokens += input_tokens @@ -236,7 +236,7 @@ class Agent: self.max_concurrent = config.concurrent self.timeout = httpx.Timeout(connect=5, read=config.timeout, write=300, pool=10) self.thinking = config.thinking - self.logger = config.logger or global_logger + self.logger = config.logger self.total_error_counter = TotalErrorCounter(logger=self.logger) # 新增:用于统计最终未解决的错误 self.unresolved_error_lock = Lock() @@ -254,7 +254,7 @@ class Agent: data[field_thinking] = val_disable def _prepare_request_data( - self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 + self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 ): if temperature is None: temperature = self.temperature @@ -276,16 +276,16 @@ class Agent: return headers, data async def send_async( - self, - client: httpx.AsyncClient, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, - best_partial_result: dict | None = None, + self, + client: httpx.AsyncClient, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt @@ -422,13 +422,13 @@ class Agent: ) async def send_prompts_async( - self, - prompts: list[str], - system_prompt: str | None = None, - max_concurrent: int | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + max_concurrent: int | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: max_concurrent = ( self.max_concurrent if max_concurrent is None else max_concurrent @@ -439,7 +439,7 @@ class Agent: ) self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) # 新增:在每次批量发送前重置计数器 @@ -459,9 +459,8 @@ class Agent: ) async with httpx.AsyncClient( - trust_env=False, proxies=proxies, verify=False, limits=limits + trust_env=False, proxies=proxies, verify=False, limits=limits ) as client: - async def send_with_semaphore(p_text: str): async with semaphore: result = await self.send_async( @@ -491,24 +490,24 @@ class Agent: # 新增:打印token使用统计 token_stats = self.token_counter.get_stats() self.logger.info( - f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " - f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " - f"总计: {token_stats['total_tokens']/1000:.2f}K" + f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " + f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " + f"总计: {token_stats['total_tokens'] / 1000:.2f}K" ) return results def send( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str = None, - retry=True, - retry_count=0, - pre_send_handler=None, - result_handler=None, - error_result_handler=None, - best_partial_result: dict | None = None, + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str = None, + retry=True, + retry_count=0, + pre_send_handler=None, + result_handler=None, + error_result_handler=None, + best_partial_result: dict | None = None, ) -> Any: if system_prompt is None: system_prompt = self.system_prompt @@ -641,14 +640,14 @@ class Agent: ) def _send_prompt_count( - self, - client: httpx.Client, - prompt: str, - system_prompt: None | str, - count: PromptsCounter, - pre_send_handler, - result_handler, - error_result_handler, + self, + client: httpx.Client, + prompt: str, + system_prompt: None | str, + count: PromptsCounter, + pre_send_handler, + result_handler, + error_result_handler, ) -> Any: result = self.send( client, @@ -662,12 +661,12 @@ class Agent: return result def send_prompts( - self, - prompts: list[str], - system_prompt: str | None = None, - pre_send_handler: PreSendHandlerType = None, - result_handler: ResultHandlerType = None, - error_result_handler: ErrorResultHandlerType = None, + self, + prompts: list[str], + system_prompt: str | None = None, + pre_send_handler: PreSendHandlerType = None, + result_handler: ResultHandlerType = None, + error_result_handler: ErrorResultHandlerType = None, ) -> list[Any]: self.logger.info( f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}" @@ -676,7 +675,7 @@ class Agent: f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" ) self.total_error_counter.max_errors_count = ( - len(prompts) // MAX_REQUESTS_PER_ERROR + len(prompts) // MAX_REQUESTS_PER_ERROR ) # 新增:在每次批量发送前重置计数器 @@ -697,7 +696,7 @@ class Agent: ) proxies = get_httpx_proxies() if USE_PROXY else None with httpx.Client( - trust_env=False, proxies=proxies, verify=False, limits=limits + trust_env=False, proxies=proxies, verify=False, limits=limits ) as client: clients = itertools.repeat(client, len(prompts)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: @@ -721,9 +720,9 @@ class Agent: # 新增:打印token使用统计 token_stats = self.token_counter.get_stats() self.logger.info( - f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " - f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " - f"总计: {token_stats['total_tokens']/1000:.2f}K" + f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), " + f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), " + f"总计: {token_stats['total_tokens'] / 1000:.2f}K" ) return output_list diff --git a/docutranslate/converter/base.py b/docutranslate/converter/base.py index c807148..9f106b1 100644 --- a/docutranslate/converter/base.py +++ b/docutranslate/converter/base.py @@ -12,7 +12,7 @@ from docutranslate.logger import global_logger @dataclass(kw_only=True) class ConverterConfig(ABC): - logger: Logger | None = None + logger: Logger = global_logger @abstractmethod def gethash(self) -> Hashable: @@ -23,13 +23,11 @@ class Converter(ABC): def __init__(self, config: ConverterConfig | None = None): self.config = config if config: - self.logger = config.logger or global_logger - else: - self.logger = global_logger + self.logger = config.logger @abstractmethod def convert(self, document: Document) -> Document: ... async def convert_async(self, document: Document) -> Document: - ... \ No newline at end of file + ... diff --git a/docutranslate/translator/base.py b/docutranslate/translator/base.py index 0a97caf..0a47be8 100644 --- a/docutranslate/translator/base.py +++ b/docutranslate/translator/base.py @@ -1,29 +1,35 @@ # SPDX-FileCopyrightText: 2025 QinHan # SPDX-License-Identifier: MPL-2.0 +from abc import ABC, abstractmethod from dataclasses import dataclass from logging import Logger -from typing import TypeVar,Generic -from abc import ABC,abstractmethod +from typing import TypeVar, Generic + from docutranslate.ir.document import Document from docutranslate.logger import global_logger @dataclass(kw_only=True) class TranslatorConfig: - logger:Logger|None=None + logger: Logger = global_logger -T=TypeVar('T',bound=Document) -class Translator(ABC,Generic[T]): +T = TypeVar('T', bound=Document) + + +class Translator(ABC, Generic[T]): """ 翻译中间文本(原地替换),Translator不做格式转换 """ - def __init__(self,config:TranslatorConfig|None=None): - self.config=config - self.logger=config.logger or global_logger + + def __init__(self, config: TranslatorConfig | None = None): + self.config = config + self.logger = config.logger or global_logger + @abstractmethod - def translate(self, document:T) -> Document: + def translate(self, document: T) -> Document: ... + @abstractmethod async def translate_async(self, document: T) -> Document: - ... \ No newline at end of file + ... diff --git a/docutranslate/workflow/base.py b/docutranslate/workflow/base.py index ba42dde..57d7c78 100644 --- a/docutranslate/workflow/base.py +++ b/docutranslate/workflow/base.py @@ -9,11 +9,12 @@ from typing import Self, Generic, TypeVar from docutranslate.exporter.base import Exporter from docutranslate.ir.attachment_manager import AttachMentManager from docutranslate.ir.document import Document +from docutranslate.logger import global_logger @dataclass(kw_only=True) class WorkflowConfig: - logger: Logger | None = None + logger: Logger = global_logger T_Config = TypeVar("T_Config", bound=WorkflowConfig)