添加默认logger

This commit is contained in:
xunbu
2025-09-09 17:32:08 +08:00
parent 787009dcaa
commit e82f6f1d15
4 changed files with 88 additions and 84 deletions

View File

@@ -40,7 +40,7 @@ class PartialAgentResultError(ValueError):
@dataclass(kw_only=True) @dataclass(kw_only=True)
class AgentConfig: class AgentConfig:
logger: logging.Logger logger: logging.Logger = global_logger
base_url: str base_url: str
api_key: str | None = None api_key: str | None = None
model_id: str model_id: str
@@ -111,14 +111,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取cached_tokens # 尝试从不同格式获取cached_tokens
# 格式1: input_tokens_details.cached_tokens # 格式1: input_tokens_details.cached_tokens
if ( if (
"input_tokens_details" in usage "input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"] and "cached_tokens" in usage["input_tokens_details"]
): ):
cached_tokens = usage["input_tokens_details"]["cached_tokens"] cached_tokens = usage["input_tokens_details"]["cached_tokens"]
# 格式2: prompt_tokens_details.cached_tokens # 格式2: prompt_tokens_details.cached_tokens
elif ( elif (
"prompt_tokens_details" in usage "prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"] and "cached_tokens" in usage["prompt_tokens_details"]
): ):
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
# 格式3: prompt_cache_hit_tokens (直接在usage下) # 格式3: prompt_cache_hit_tokens (直接在usage下)
@@ -128,14 +128,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取reasoning_tokens # 尝试从不同格式获取reasoning_tokens
# 格式1: output_tokens_details.reasoning_tokens # 格式1: output_tokens_details.reasoning_tokens
if ( if (
"output_tokens_details" in usage "output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"] and "reasoning_tokens" in usage["output_tokens_details"]
): ):
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"] reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
# 格式2: completion_tokens_details.reasoning_tokens # 格式2: completion_tokens_details.reasoning_tokens
elif ( elif (
"completion_tokens_details" in usage "completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"] and "reasoning_tokens" in usage["completion_tokens_details"]
): ):
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"] reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
@@ -153,11 +153,11 @@ class TokenCounter:
self.logger = logger self.logger = logger
def add( def add(
self, self,
input_tokens: int, input_tokens: int,
cached_tokens: int, cached_tokens: int,
output_tokens: int, output_tokens: int,
reasoning_tokens: int, reasoning_tokens: int,
): ):
with self.lock: with self.lock:
self.input_tokens += input_tokens self.input_tokens += input_tokens
@@ -236,7 +236,7 @@ class Agent:
self.max_concurrent = config.concurrent self.max_concurrent = config.concurrent
self.timeout = httpx.Timeout(connect=5, read=config.timeout, write=300, pool=10) self.timeout = httpx.Timeout(connect=5, read=config.timeout, write=300, pool=10)
self.thinking = config.thinking self.thinking = config.thinking
self.logger = config.logger or global_logger self.logger = config.logger
self.total_error_counter = TotalErrorCounter(logger=self.logger) self.total_error_counter = TotalErrorCounter(logger=self.logger)
# 新增:用于统计最终未解决的错误 # 新增:用于统计最终未解决的错误
self.unresolved_error_lock = Lock() self.unresolved_error_lock = Lock()
@@ -254,7 +254,7 @@ class Agent:
data[field_thinking] = val_disable data[field_thinking] = val_disable
def _prepare_request_data( def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9 self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
): ):
if temperature is None: if temperature is None:
temperature = self.temperature temperature = self.temperature
@@ -276,16 +276,16 @@ class Agent:
return headers, data return headers, data
async def send_async( async def send_async(
self, self,
client: httpx.AsyncClient, client: httpx.AsyncClient,
prompt: str, prompt: str,
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None, best_partial_result: dict | None = None,
) -> Any: ) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
@@ -422,13 +422,13 @@ class Agent:
) )
async def send_prompts_async( async def send_prompts_async(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
max_concurrent: int | None = None, max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]: ) -> list[Any]:
max_concurrent = ( max_concurrent = (
self.max_concurrent if max_concurrent is None else max_concurrent self.max_concurrent if max_concurrent is None else max_concurrent
@@ -439,7 +439,7 @@ class Agent:
) )
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}") self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = ( self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR len(prompts) // MAX_REQUESTS_PER_ERROR
) )
# 新增:在每次批量发送前重置计数器 # 新增:在每次批量发送前重置计数器
@@ -459,9 +459,8 @@ class Agent:
) )
async with httpx.AsyncClient( async with httpx.AsyncClient(
trust_env=False, proxies=proxies, verify=False, limits=limits trust_env=False, proxies=proxies, verify=False, limits=limits
) as client: ) as client:
async def send_with_semaphore(p_text: str): async def send_with_semaphore(p_text: str):
async with semaphore: async with semaphore:
result = await self.send_async( result = await self.send_async(
@@ -491,24 +490,24 @@ class Agent:
# 新增打印token使用统计 # 新增打印token使用统计
token_stats = self.token_counter.get_stats() token_stats = self.token_counter.get_stats()
self.logger.info( self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K" f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
) )
return results return results
def send( def send(
self, self,
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str = None, system_prompt: None | str = None,
retry=True, retry=True,
retry_count=0, retry_count=0,
pre_send_handler=None, pre_send_handler=None,
result_handler=None, result_handler=None,
error_result_handler=None, error_result_handler=None,
best_partial_result: dict | None = None, best_partial_result: dict | None = None,
) -> Any: ) -> Any:
if system_prompt is None: if system_prompt is None:
system_prompt = self.system_prompt system_prompt = self.system_prompt
@@ -641,14 +640,14 @@ class Agent:
) )
def _send_prompt_count( def _send_prompt_count(
self, self,
client: httpx.Client, client: httpx.Client,
prompt: str, prompt: str,
system_prompt: None | str, system_prompt: None | str,
count: PromptsCounter, count: PromptsCounter,
pre_send_handler, pre_send_handler,
result_handler, result_handler,
error_result_handler, error_result_handler,
) -> Any: ) -> Any:
result = self.send( result = self.send(
client, client,
@@ -662,12 +661,12 @@ class Agent:
return result return result
def send_prompts( def send_prompts(
self, self,
prompts: list[str], prompts: list[str],
system_prompt: str | None = None, system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None, pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None, result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None, error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]: ) -> list[Any]:
self.logger.info( self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}" f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
@@ -676,7 +675,7 @@ class Agent:
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}" f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
) )
self.total_error_counter.max_errors_count = ( self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR len(prompts) // MAX_REQUESTS_PER_ERROR
) )
# 新增:在每次批量发送前重置计数器 # 新增:在每次批量发送前重置计数器
@@ -697,7 +696,7 @@ class Agent:
) )
proxies = get_httpx_proxies() if USE_PROXY else None proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client( with httpx.Client(
trust_env=False, proxies=proxies, verify=False, limits=limits trust_env=False, proxies=proxies, verify=False, limits=limits
) as client: ) as client:
clients = itertools.repeat(client, len(prompts)) clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
@@ -721,9 +720,9 @@ class Agent:
# 新增打印token使用统计 # 新增打印token使用统计
token_stats = self.token_counter.get_stats() token_stats = self.token_counter.get_stats()
self.logger.info( self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), " f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), " f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K" f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
) )
return output_list return output_list

View File

@@ -12,7 +12,7 @@ from docutranslate.logger import global_logger
@dataclass(kw_only=True) @dataclass(kw_only=True)
class ConverterConfig(ABC): class ConverterConfig(ABC):
logger: Logger | None = None logger: Logger = global_logger
@abstractmethod @abstractmethod
def gethash(self) -> Hashable: def gethash(self) -> Hashable:
@@ -23,9 +23,7 @@ class Converter(ABC):
def __init__(self, config: ConverterConfig | None = None): def __init__(self, config: ConverterConfig | None = None):
self.config = config self.config = config
if config: if config:
self.logger = config.logger or global_logger self.logger = config.logger
else:
self.logger = global_logger
@abstractmethod @abstractmethod
def convert(self, document: Document) -> Document: def convert(self, document: Document) -> Document:

View File

@@ -1,29 +1,35 @@
# SPDX-FileCopyrightText: 2025 QinHan # SPDX-FileCopyrightText: 2025 QinHan
# SPDX-License-Identifier: MPL-2.0 # SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from typing import TypeVar,Generic from typing import TypeVar, Generic
from abc import ABC,abstractmethod
from docutranslate.ir.document import Document from docutranslate.ir.document import Document
from docutranslate.logger import global_logger from docutranslate.logger import global_logger
@dataclass(kw_only=True) @dataclass(kw_only=True)
class TranslatorConfig: class TranslatorConfig:
logger:Logger|None=None logger: Logger = global_logger
T=TypeVar('T',bound=Document)
class Translator(ABC,Generic[T]): T = TypeVar('T', bound=Document)
class Translator(ABC, Generic[T]):
""" """
翻译中间文本原地替换Translator不做格式转换 翻译中间文本原地替换Translator不做格式转换
""" """
def __init__(self,config:TranslatorConfig|None=None):
self.config=config def __init__(self, config: TranslatorConfig | None = None):
self.logger=config.logger or global_logger self.config = config
self.logger = config.logger or global_logger
@abstractmethod @abstractmethod
def translate(self, document:T) -> Document: def translate(self, document: T) -> Document:
... ...
@abstractmethod @abstractmethod
async def translate_async(self, document: T) -> Document: async def translate_async(self, document: T) -> Document:
... ...

View File

@@ -9,11 +9,12 @@ from typing import Self, Generic, TypeVar
from docutranslate.exporter.base import Exporter from docutranslate.exporter.base import Exporter
from docutranslate.ir.attachment_manager import AttachMentManager from docutranslate.ir.attachment_manager import AttachMentManager
from docutranslate.ir.document import Document from docutranslate.ir.document import Document
from docutranslate.logger import global_logger
@dataclass(kw_only=True) @dataclass(kw_only=True)
class WorkflowConfig: class WorkflowConfig:
logger: Logger | None = None logger: Logger = global_logger
T_Config = TypeVar("T_Config", bound=WorkflowConfig) T_Config = TypeVar("T_Config", bound=WorkflowConfig)