"""Database connection-pool monitoring.

Aggregates pool counters across registered SQLAlchemy engines, records slow
queries and database errors through global ``Engine`` event listeners, and can
run a daemon thread that periodically compares pool usage against an alert
threshold.
"""
import threading
import time
from collections import deque
from datetime import datetime
from typing import Any, Dict

from sqlalchemy import event
from sqlalchemy.engine import Engine

from .logging_config import get_logger

logger = get_logger(__name__)
db_logger = get_logger("sqlalchemy.engine")

# Upper bounds on the number of in-memory slow-query / error records.
_MAX_SLOW_QUERIES = 1000
_MAX_CONNECTION_ERRORS = 100
# Size of the "recent" window reported by get_monitoring_report(), seconds.
_RECENT_WINDOW_SECONDS = 300
# Truncation length for SQL text stored in slow-query records.
_SQL_RECORD_LIMIT = 200

# Global monitoring data. The bounded deques discard their oldest entry
# automatically once full, so no copy-and-slice trimming is needed.
_pool_stats = {
    'total_connections': 0,
    'checked_in': 0,
    'checked_out': 0,
    'overflow': 0,
    'slow_queries': deque(maxlen=_MAX_SLOW_QUERIES),
    'connection_errors': deque(maxlen=_MAX_CONNECTION_ERRORS),
    'peak_connections': 0,
    'last_reset': datetime.now(),
}
# Guards the record deques: they are appended to from the event listeners
# below (run by whichever thread executes a query) and read by reports.
_stats_lock = threading.Lock()

_alert_thresholds = {
    'pool_usage_percent': 80,   # warn when usage_percent reaches this value
    'slow_query_time': 5.0,     # seconds; queries at/above this are recorded
    'alert_cooldown': 300,      # seconds before the same alert may repeat
}

# alert key -> time.time() of the last emitted alert of that kind
_last_alerts = {}
_engines = []

# Background monitor thread started by start_monitoring(), if any.
_monitor_thread = None
_monitor_start_lock = threading.Lock()


def register_engine(engine):
    """Register *engine* so its pool is included in get_pool_status().

    Registering the same engine object again is a no-op, so repeated calls
    (for example from a second start_monitoring()) do not double-count it.
    """
    if not any(registered is engine for registered in _engines):
        _engines.append(engine)


def get_pool_status() -> Dict[str, Any]:
    """Aggregate pool counters over all registered engines.

    Returns:
        A dict with ``total`` (sum of pool sizes), ``checked_in``,
        ``checked_out``, ``overflow`` (overflow connections, never negative)
        and ``usage_percent`` (checked_out / total * 100, rounded to two
        decimals; 0 when total is 0).

    Pools without a ``size`` attribute are skipped. As a side effect the
    recorded ``peak_connections`` is raised to the current checked_out count
    when that is higher.
    """
    stats = {
        'total': 0,
        'checked_in': 0,
        'checked_out': 0,
        'overflow': 0,
        'usage_percent': 0,
    }
    for engine in _engines:
        pool = engine.pool
        if not hasattr(pool, 'size'):
            continue
        stats['total'] += pool.size() if callable(pool.size) else pool.size
        stats['checked_in'] += pool.checkedin() if hasattr(pool, 'checkedin') else 0
        stats['checked_out'] += pool.checkedout() if hasattr(pool, 'checkedout') else 0
        if hasattr(pool, 'overflow'):
            # QueuePool.overflow() starts at -pool_size and stays negative
            # until the pool is full; clamp so only real overflow is summed.
            stats['overflow'] += max(0, pool.overflow())

    if stats['total'] > 0:
        stats['usage_percent'] = round((stats['checked_out'] / stats['total']) * 100, 2)

    if stats['checked_out'] > _pool_stats['peak_connections']:
        _pool_stats['peak_connections'] = stats['checked_out']

    return stats


def check_pool_alerts():
    """Warn when pool usage reaches the configured threshold.

    A warning of the same kind is suppressed until more than
    ``alert_cooldown`` seconds have passed since the previous one.
    """
    now = time.time()
    stats = get_pool_status()

    if stats.get('usage_percent', 0) < _alert_thresholds['pool_usage_percent']:
        return

    alert_key = 'pool_usage'
    last_alert = _last_alerts.get(alert_key)
    if last_alert is not None and (now - last_alert) <= _alert_thresholds['alert_cooldown']:
        return

    db_logger.warning(
        f"数据库连接池告警: 使用率 {stats['usage_percent']}% 超过阈值 "
        f"(已使用: {stats['checked_out']}/{stats['total']})"
    )
    _last_alerts[alert_key] = now


def log_slow_query(sql: str, duration: float):
    """Record a slow query.

    Args:
        sql: Statement text; only the first 200 characters are kept.
        duration: Execution time in seconds.
    """
    record = {
        'sql': sql[:_SQL_RECORD_LIMIT],
        'duration': duration,
        'timestamp': time.time(),
    }
    with _stats_lock:
        _pool_stats['slow_queries'].append(record)


def log_connection_error(error: str):
    """Record a database error message with the current timestamp."""
    record = {
        'error': error,
        'timestamp': time.time(),
    }
    with _stats_lock:
        _pool_stats['connection_errors'].append(record)


def get_monitoring_report() -> Dict[str, Any]:
    """Build a snapshot report of pool status and recent problems.

    ``recent_5min`` counts slow queries and errors whose timestamp falls
    within the last 300 seconds.
    """
    stats = get_pool_status()
    now = time.time()

    # Snapshot under the lock so concurrent appends cannot disturb iteration.
    with _stats_lock:
        slow_queries = list(_pool_stats['slow_queries'])
        errors = list(_pool_stats['connection_errors'])

    recent_slow_count = sum(
        1 for q in slow_queries if (now - q['timestamp']) < _RECENT_WINDOW_SECONDS
    )
    recent_error_count = sum(
        1 for e in errors if (now - e['timestamp']) < _RECENT_WINDOW_SECONDS
    )

    return {
        'timestamp': datetime.now().isoformat(),
        'pool_status': stats,
        'peak_connections': _pool_stats['peak_connections'],
        'recent_5min': {
            'slow_queries_count': recent_slow_count,
            'connection_errors_count': recent_error_count,
        },
        'last_reset': _pool_stats['last_reset'].isoformat(),
    }


def monitoring_task():
    """Check pool alerts forever: every 30 s, or 60 s after a failed check."""
    while True:
        try:
            check_pool_alerts()
        except Exception as e:
            # Top-level loop of a daemon thread: log (with traceback) and
            # keep running rather than letting the thread die.
            db_logger.error(f"数据库监控任务异常: {e}", exc_info=True)
            time.sleep(60)
        else:
            time.sleep(30)


def start_monitoring():
    """Register the project's engines and start the background monitor.

    Safe to call more than once: engines are registered only once and a
    second monitor thread is not started while the first is alive.
    """
    global _monitor_thread

    # Deferred import to avoid a circular dependency.
    from .database import railway_engine, tunnel_engine

    register_engine(railway_engine)
    register_engine(tunnel_engine)

    with _monitor_start_lock:
        if _monitor_thread is not None and _monitor_thread.is_alive():
            return
        _monitor_thread = threading.Thread(
            target=monitoring_task, name="db-pool-monitor", daemon=True
        )
        _monitor_thread.start()

    db_logger.info("数据库连接池监控已启动")


# --- SQL execution-time monitoring -----------------------------------------
# Start times (time.perf_counter()) of in-flight statements, keyed by
# id(cursor). Entries are removed in after_cursor_execute, or in handle_error
# when the statement fails and after_cursor_execute is never reached.
_query_start_times = {}


@event.listens_for(Engine, "before_cursor_execute")
def receive_before_cursor_execute(conn, cursor, statement, params, context, executemany):
    """Remember when *cursor* started executing."""
    _query_start_times[id(cursor)] = time.perf_counter()


@event.listens_for(Engine, "after_cursor_execute")
def receive_after_cursor_execute(conn, cursor, statement, params, context, executemany):
    """Measure the statement's duration and record it when it is slow."""
    start_time = _query_start_times.pop(id(cursor), None)
    if start_time is None:
        return
    # perf_counter is monotonic, so wall-clock adjustments cannot skew it.
    duration = time.perf_counter() - start_time
    if duration >= _alert_thresholds['slow_query_time']:
        log_slow_query(statement, duration)
        db_logger.warning(f"慢查询: {duration:.2f}s - {statement[:100]}...")


@event.listens_for(Engine, "handle_error")
def receive_handle_error(exception_context):
    """Record and log a database error."""
    # A failed statement never reaches after_cursor_execute, so discard its
    # start time here; otherwise the entry would never be removed.
    cursor = getattr(exception_context, 'cursor', None)
    if cursor is not None:
        _query_start_times.pop(id(cursor), None)

    error_msg = str(exception_context.original_exception)
    log_connection_error(error_msg)
    db_logger.error(f"数据库错误: {error_msg}")


def log_pool_status():
    """Write the current aggregated pool status to the log."""
    stats = get_pool_status()
    db_logger.info(
        f"数据库连接池状态: 使用率 {stats['usage_percent']}% "
        f"(已用: {stats['checked_out']}, 空闲: {stats['checked_in']}, 总计: {stats['total']})"
    )