添加默认logger

This commit is contained in:
xunbu
2025-09-09 17:32:08 +08:00
parent 787009dcaa
commit e82f6f1d15
4 changed files with 88 additions and 84 deletions

View File

@@ -40,7 +40,7 @@ class PartialAgentResultError(ValueError):
@dataclass(kw_only=True)
class AgentConfig:
logger: logging.Logger
logger: logging.Logger = global_logger
base_url: str
api_key: str | None = None
model_id: str
@@ -111,14 +111,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取cached_tokens
# 格式1: input_tokens_details.cached_tokens
if (
"input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"]
"input_tokens_details" in usage
and "cached_tokens" in usage["input_tokens_details"]
):
cached_tokens = usage["input_tokens_details"]["cached_tokens"]
# 格式2: prompt_tokens_details.cached_tokens
elif (
"prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"]
"prompt_tokens_details" in usage
and "cached_tokens" in usage["prompt_tokens_details"]
):
cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
# 格式3: prompt_cache_hit_tokens (直接在usage下)
@@ -128,14 +128,14 @@ def extract_token_info(response_data: dict) -> tuple[int, int, int, int]:
# 尝试从不同格式获取reasoning_tokens
# 格式1: output_tokens_details.reasoning_tokens
if (
"output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"]
"output_tokens_details" in usage
and "reasoning_tokens" in usage["output_tokens_details"]
):
reasoning_tokens = usage["output_tokens_details"]["reasoning_tokens"]
# 格式2: completion_tokens_details.reasoning_tokens
elif (
"completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"]
"completion_tokens_details" in usage
and "reasoning_tokens" in usage["completion_tokens_details"]
):
reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
@@ -153,11 +153,11 @@ class TokenCounter:
self.logger = logger
def add(
self,
input_tokens: int,
cached_tokens: int,
output_tokens: int,
reasoning_tokens: int,
self,
input_tokens: int,
cached_tokens: int,
output_tokens: int,
reasoning_tokens: int,
):
with self.lock:
self.input_tokens += input_tokens
@@ -236,7 +236,7 @@ class Agent:
self.max_concurrent = config.concurrent
self.timeout = httpx.Timeout(connect=5, read=config.timeout, write=300, pool=10)
self.thinking = config.thinking
self.logger = config.logger or global_logger
self.logger = config.logger
self.total_error_counter = TotalErrorCounter(logger=self.logger)
# 新增:用于统计最终未解决的错误
self.unresolved_error_lock = Lock()
@@ -254,7 +254,7 @@ class Agent:
data[field_thinking] = val_disable
def _prepare_request_data(
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
self, prompt: str, system_prompt: str, temperature=None, top_p=0.9
):
if temperature is None:
temperature = self.temperature
@@ -276,16 +276,16 @@ class Agent:
return headers, data
async def send_async(
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None,
self,
client: httpx.AsyncClient,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
@@ -422,13 +422,13 @@ class Agent:
)
async def send_prompts_async(
self,
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
self,
prompts: list[str],
system_prompt: str | None = None,
max_concurrent: int | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
max_concurrent = (
self.max_concurrent if max_concurrent is None else max_concurrent
@@ -439,7 +439,7 @@ class Agent:
)
self.logger.info(f"预计发送{total}个请求,并发请求数:{max_concurrent}")
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
@@ -459,9 +459,8 @@ class Agent:
)
async with httpx.AsyncClient(
trust_env=False, proxies=proxies, verify=False, limits=limits
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
async def send_with_semaphore(p_text: str):
async with semaphore:
result = await self.send_async(
@@ -491,24 +490,24 @@ class Agent:
# 新增打印token使用统计
token_stats = self.token_counter.get_stats()
self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K"
f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
)
return results
def send(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
best_partial_result: dict | None = None,
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str = None,
retry=True,
retry_count=0,
pre_send_handler=None,
result_handler=None,
error_result_handler=None,
best_partial_result: dict | None = None,
) -> Any:
if system_prompt is None:
system_prompt = self.system_prompt
@@ -641,14 +640,14 @@ class Agent:
)
def _send_prompt_count(
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler,
self,
client: httpx.Client,
prompt: str,
system_prompt: None | str,
count: PromptsCounter,
pre_send_handler,
result_handler,
error_result_handler,
) -> Any:
result = self.send(
client,
@@ -662,12 +661,12 @@ class Agent:
return result
def send_prompts(
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
self,
prompts: list[str],
system_prompt: str | None = None,
pre_send_handler: PreSendHandlerType = None,
result_handler: ResultHandlerType = None,
error_result_handler: ErrorResultHandlerType = None,
) -> list[Any]:
self.logger.info(
f"base-url:{self.baseurl},model-id:{self.model_id},concurrent:{self.max_concurrent},temperature:{self.temperature}"
@@ -676,7 +675,7 @@ class Agent:
f"预计发送{len(prompts)}个请求,并发请求数:{self.max_concurrent}"
)
self.total_error_counter.max_errors_count = (
len(prompts) // MAX_REQUESTS_PER_ERROR
len(prompts) // MAX_REQUESTS_PER_ERROR
)
# 新增:在每次批量发送前重置计数器
@@ -697,7 +696,7 @@ class Agent:
)
proxies = get_httpx_proxies() if USE_PROXY else None
with httpx.Client(
trust_env=False, proxies=proxies, verify=False, limits=limits
trust_env=False, proxies=proxies, verify=False, limits=limits
) as client:
clients = itertools.repeat(client, len(prompts))
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
@@ -721,9 +720,9 @@ class Agent:
# 新增打印token使用统计
token_stats = self.token_counter.get_stats()
self.logger.info(
f"Token使用统计 - 输入: {token_stats['input_tokens']/1000:.2f}K(含cached: {token_stats['cached_tokens']/1000:.2f}K), "
f"输出: {token_stats['output_tokens']/1000:.2f}K(含reasoning: {token_stats['reasoning_tokens']/1000:.2f}K), "
f"总计: {token_stats['total_tokens']/1000:.2f}K"
f"Token使用统计 - 输入: {token_stats['input_tokens'] / 1000:.2f}K(含cached: {token_stats['cached_tokens'] / 1000:.2f}K), "
f"输出: {token_stats['output_tokens'] / 1000:.2f}K(含reasoning: {token_stats['reasoning_tokens'] / 1000:.2f}K), "
f"总计: {token_stats['total_tokens'] / 1000:.2f}K"
)
return output_list

View File

@@ -12,7 +12,7 @@ from docutranslate.logger import global_logger
@dataclass(kw_only=True)
class ConverterConfig(ABC):
logger: Logger | None = None
logger: Logger = global_logger
@abstractmethod
def gethash(self) -> Hashable:
@@ -23,9 +23,7 @@ class Converter(ABC):
def __init__(self, config: ConverterConfig | None = None):
self.config = config
if config:
self.logger = config.logger or global_logger
else:
self.logger = global_logger
self.logger = config.logger
@abstractmethod
def convert(self, document: Document) -> Document:

View File

@@ -1,29 +1,35 @@
# SPDX-FileCopyrightText: 2025 QinHan
# SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from typing import TypeVar,Generic
from abc import ABC,abstractmethod
from typing import TypeVar, Generic
from docutranslate.ir.document import Document
from docutranslate.logger import global_logger
@dataclass(kw_only=True)
class TranslatorConfig:
logger:Logger|None=None
logger: Logger = global_logger
T=TypeVar('T',bound=Document)
class Translator(ABC,Generic[T]):
T = TypeVar('T', bound=Document)
class Translator(ABC, Generic[T]):
"""
翻译中间文本原地替换Translator不做格式转换
"""
def __init__(self,config:TranslatorConfig|None=None):
self.config=config
self.logger=config.logger or global_logger
def __init__(self, config: TranslatorConfig | None = None):
self.config = config
self.logger = config.logger or global_logger
@abstractmethod
def translate(self, document:T) -> Document:
def translate(self, document: T) -> Document:
...
@abstractmethod
async def translate_async(self, document: T) -> Document:
...

View File

@@ -9,11 +9,12 @@ from typing import Self, Generic, TypeVar
from docutranslate.exporter.base import Exporter
from docutranslate.ir.attachment_manager import AttachMentManager
from docutranslate.ir.document import Document
from docutranslate.logger import global_logger
@dataclass(kw_only=True)
class WorkflowConfig:
logger: Logger | None = None
logger: Logger = global_logger
T_Config = TypeVar("T_Config", bound=WorkflowConfig)