diff --git a/backend/common/dataclasses.py b/backend/common/dataclasses.py index b838e9ba..389be685 100644 --- a/backend/common/dataclasses.py +++ b/backend/common/dataclasses.py @@ -70,6 +70,6 @@ class UploadUrl: class SnowflakeInfo: timestamp: int datetime: str - cluster_id: int - node_id: int + datacenter_id: int + worker_id: int sequence: int diff --git a/backend/core/conf.py b/backend/core/conf.py index 6f22cb64..b093cdb8 100644 --- a/backend/core/conf.py +++ b/backend/core/conf.py @@ -52,6 +52,15 @@ class Settings(BaseSettings): # Redis REDIS_TIMEOUT: int = 5 + # .env Snowflake + SNOWFLAKE_DATACENTER_ID: int | None = None + SNOWFLAKE_WORKER_ID: int | None = None + + # Snowflake + SNOWFLAKE_REDIS_PREFIX: str = 'fba:snowflake' + SNOWFLAKE_HEARTBEAT_INTERVAL_SECONDS: int = 30 + SNOWFLAKE_NODE_TTL_SECONDS: int = 60 + # .env Token TOKEN_SECRET_KEY: str # 密钥 secrets.token_urlsafe(32) diff --git a/backend/core/registrar.py b/backend/core/registrar.py index 1e20e500..dd2408a9 100644 --- a/backend/core/registrar.py +++ b/backend/core/registrar.py @@ -34,6 +34,7 @@ from backend.utils.health_check import ensure_unique_route_names, http_limit_callback from backend.utils.openapi import simplify_operation_ids from backend.utils.serializers import MsgSpecJSONResponse +from backend.utils.snowflake import snowflake @asynccontextmanager @@ -57,11 +58,17 @@ async def register_init(app: FastAPI) -> AsyncGenerator[None, None]: http_callback=http_limit_callback, ) + # 初始化 snowflake 节点 + await snowflake.init() + # 创建操作日志任务 create_task(OperaLogMiddleware.consumer()) yield + # 释放 snowflake 节点 + await snowflake.shutdown() + # 关闭 redis 连接 await redis_client.aclose() diff --git a/backend/utils/snowflake.py b/backend/utils/snowflake.py index e91883ce..267bd8cd 100644 --- a/backend/utils/snowflake.py +++ b/backend/utils/snowflake.py @@ -1,15 +1,22 @@ +import asyncio +import datetime +import os +import threading import time from dataclasses import dataclass from backend.common.dataclasses 
import SnowflakeInfo from backend.common.exception import errors +from backend.common.log import log from backend.core.conf import settings +from backend.database.redis import redis_client +from backend.utils.timezone import timezone @dataclass(frozen=True) class SnowflakeConfig: - """雪花算法配置类""" + """雪花算法配置类,采用 Twitter 原版 Snowflake 64 位 ID 位分配配置(通用标准)""" # 位分配 WORKER_ID_BITS: int = 5 @@ -29,81 +36,184 @@ class SnowflakeConfig: # 元年时间戳 EPOCH: int = 1262275200000 - # 默认值 - DEFAULT_DATACENTER_ID: int = 1 - DEFAULT_WORKER_ID: int = 0 - DEFAULT_SEQUENCE: int = 0 + # 时钟回拨容忍阈值,应对 NTP 自动同步引起的正常回跳(非标准) + CLOCK_BACKWARD_TOLERANCE_MS: int = 10_000 + + +class SnowflakeNodeManager: + """雪花算法节点管理器,负责从 Redis 分配和管理节点 ID""" + + def __init__(self) -> None: + """初始化节点管理器""" + self.datacenter_id: int | None = None + self.worker_id: int | None = None + self.node_redis_prefix: str = f'{settings.SNOWFLAKE_REDIS_PREFIX}:nodes' + self._heartbeat_task: asyncio.Task | None = None + + async def acquire_node_id(self) -> tuple[int, int]: + """从 Redis 获取可用的 datacenter_id 和 worker_id""" + occupied_nodes = set() + async for key in redis_client.scan_iter(match=f'{self.node_redis_prefix}:*'): + parts = key.split(':') + if len(parts) >= 5: + try: + datacenter_id = int(parts[-2]) + worker_id = int(parts[-1]) + occupied_nodes.add((datacenter_id, worker_id)) + except ValueError: + continue + + # 顺序查找第一个可用的 ID 组合 + for datacenter_id in range(SnowflakeConfig.MAX_DATACENTER_ID + 1): + for worker_id in range(SnowflakeConfig.MAX_WORKER_ID + 1): + if (datacenter_id, worker_id) not in occupied_nodes and await self._register(datacenter_id, worker_id): + return datacenter_id, worker_id + + raise errors.ServerError(msg='无可用的雪花算法节点,节点已耗尽') + + async def _register(self, datacenter_id: int, worker_id: int) -> bool: + key = f'{self.node_redis_prefix}:{datacenter_id}:{worker_id}' + value = f'pid:{os.getpid()}-ts:{timezone.now().timestamp()}' + return await redis_client.set(key, value, nx=True, 
ex=settings.SNOWFLAKE_NODE_TTL_SECONDS) + + async def start_heartbeat(self, datacenter_id: int, worker_id: int) -> None: + """启动节点心跳""" + self.datacenter_id = datacenter_id + self.worker_id = worker_id + + async def heartbeat() -> None: + key = f'{self.node_redis_prefix}:{datacenter_id}:{worker_id}' + while True: + await asyncio.sleep(settings.SNOWFLAKE_HEARTBEAT_INTERVAL_SECONDS) + try: + await redis_client.expire(key, settings.SNOWFLAKE_NODE_TTL_SECONDS) + log.debug(f'雪花算法节点心跳任务开始:datacenter_id={datacenter_id}, worker_id={worker_id}') + except Exception as e: + log.error(f'雪花算法节点心跳任务失败:{e}') + + self._heartbeat_task = asyncio.create_task(heartbeat()) + + async def release(self) -> None: + """释放节点""" + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + log.debug(f'雪花算法节点心跳任务释放:datacenter_id={self.datacenter_id}, worker_id={self.worker_id}') + + if self.datacenter_id is not None and self.worker_id is not None: + key = f'{self.node_redis_prefix}:{self.datacenter_id}:{self.worker_id}' + await redis_client.delete(key) class Snowflake: """雪花算法类""" - def __init__( - self, - cluster_id: int = SnowflakeConfig.DEFAULT_DATACENTER_ID, - node_id: int = SnowflakeConfig.DEFAULT_WORKER_ID, - sequence: int = SnowflakeConfig.DEFAULT_SEQUENCE, - ) -> None: - """ - 初始化雪花算法生成器 - - :param cluster_id: 集群 ID (0-31) - :param node_id: 节点 ID (0-31) - :param sequence: 起始序列号 - """ - if cluster_id < 0 or cluster_id > SnowflakeConfig.MAX_DATACENTER_ID: - raise errors.RequestError(msg=f'集群编号必须在 0-{SnowflakeConfig.MAX_DATACENTER_ID} 之间') - if node_id < 0 or node_id > SnowflakeConfig.MAX_WORKER_ID: - raise errors.RequestError(msg=f'节点编号必须在 0-{SnowflakeConfig.MAX_WORKER_ID} 之间') - - self.node_id = node_id - self.cluster_id = cluster_id - self.sequence = sequence - self.last_timestamp = -1 + def __init__(self) -> None: + """初始化雪花算法""" + self.datacenter_id: int | None = None + self.worker_id: int | None = None + 
class Snowflake:
    """
    Snowflake ID generator.

    ``init`` must be awaited once before ``generate`` is used. The node identity comes either from
    ``settings.SNOWFLAKE_DATACENTER_ID`` / ``settings.SNOWFLAKE_WORKER_ID`` (both set) or from a
    ``SnowflakeNodeManager`` allocation in Redis (both unset).
    """

    def __init__(self) -> None:
        """Create an uninitialised generator."""
        self.datacenter_id: int | None = None
        self.worker_id: int | None = None
        self.sequence: int = 0
        self.last_timestamp: int = -1

        # Guards sequence/last_timestamp and state publication; never held across an ``await``
        self._lock = threading.Lock()
        # Serialises init()/shutdown() between coroutines without blocking the event loop
        self._init_lock = asyncio.Lock()
        self._initialized = False
        self._node_manager: SnowflakeNodeManager | None = None
        self._auto_allocated = False  # True when the IDs were allocated through Redis

    @staticmethod
    def _check_range(datacenter_id: int, worker_id: int) -> None:
        """
        Validate that both IDs fit into their bit fields

        :param datacenter_id: datacenter ID to check
        :param worker_id: worker ID to check
        :raises errors.ServerError: when either ID is out of range
        """
        if not (0 <= datacenter_id <= SnowflakeConfig.MAX_DATACENTER_ID):
            log.error(f'雪花算法 datacenter_id 配置失败,必须在 0~{SnowflakeConfig.MAX_DATACENTER_ID} 之间')
            raise errors.ServerError(msg='雪花算法数据中心配置失败,请联系系统管理员')
        if not (0 <= worker_id <= SnowflakeConfig.MAX_WORKER_ID):
            log.error(f'雪花算法 worker_id 配置失败,必须在 0~{SnowflakeConfig.MAX_WORKER_ID} 之间')
            raise errors.ServerError(msg='雪花算法工作机器配置失败,请联系系统管理员')

    async def init(self) -> None:
        """
        Resolve the node identity and mark the generator as ready

        :raises errors.ServerError: when only one of the two settings is provided, when an ID is
            out of range, or when no node can be allocated
        """
        if self._initialized:
            return

        # An asyncio lock is used here because the body awaits Redis: holding a threading.Lock
        # across an await would freeze the event loop as soon as a second coroutine called init()
        async with self._init_lock:
            if self._initialized:
                return

            env_datacenter_id = settings.SNOWFLAKE_DATACENTER_ID
            env_worker_id = settings.SNOWFLAKE_WORKER_ID
            if (env_datacenter_id is None) != (env_worker_id is None):
                log.error('雪花算法 datacenter_id 和 worker_id 配置错误,只允许同时非 None 或同时为 None')
                raise errors.ServerError(msg='雪花算法配置失败,请联系系统管理员')

            manager: SnowflakeNodeManager | None = None
            if env_datacenter_id is not None:
                # Fixed assignment from the environment; validated before anything is stored
                self._check_range(env_datacenter_id, env_worker_id)
                datacenter_id, worker_id = env_datacenter_id, env_worker_id
                log.debug(f'雪花算法使用环境变量固定节点:datacenter_id={datacenter_id}, worker_id={worker_id}')
            else:
                # Dynamic assignment through Redis
                manager = SnowflakeNodeManager()
                datacenter_id, worker_id = await manager.acquire_node_id()
                self._check_range(datacenter_id, worker_id)
                await manager.start_heartbeat(datacenter_id, worker_id)
                log.debug(f'雪花算法使用 Redis 动态分配节点:datacenter_id={datacenter_id}, worker_id={worker_id}')

            with self._lock:
                self.datacenter_id = datacenter_id
                self.worker_id = worker_id
                self._node_manager = manager
                self._auto_allocated = manager is not None
                self._initialized = True

    async def shutdown(self) -> None:
        """
        Release the Redis node, if one was allocated

        After a release the generator is marked uninitialised: the pair may be handed to another
        instance, so it must not be used again until ``init`` has allocated a new one.
        """
        async with self._init_lock:
            manager = self._node_manager
            if not (manager and self._auto_allocated):
                return
            with self._lock:
                self._initialized = False
                self._node_manager = None
                self._auto_allocated = False
            await manager.release()

    @staticmethod
    def _current_ms() -> int:
        """Return the current time as a millisecond timestamp."""
        return int(timezone.now().timestamp() * 1000)

    def _till_next_ms(self, last_timestamp: int) -> int:
        """
        Block until the clock is past ``last_timestamp``

        :param last_timestamp: millisecond timestamp that must be exceeded
        :return: the first observed timestamp greater than ``last_timestamp``
        """
        ts = self._current_ms()
        while ts <= last_timestamp:
            behind_ms = last_timestamp - ts
            # Clock is behind: sleep the known gap in one go. Same millisecond: short poll.
            time.sleep(behind_ms / 1000 if behind_ms > 0 else 0.0001)
            ts = self._current_ms()
        return ts

    def generate(self) -> int:
        """
        Generate the next snowflake ID

        NOTE: this call blocks the calling thread while waiting for the next millisecond, and for
        up to ``CLOCK_BACKWARD_TOLERANCE_MS`` after a tolerated clock rollback.

        :return: the generated ID
        :raises errors.ServerError: when the generator is not initialised or the clock moved back
            further than the tolerance
        """
        with self._lock:
            # Checked under the lock so a concurrent shutdown() cannot slip in between
            if not self._initialized:
                raise errors.ServerError(msg='雪花 ID 生成失败,雪花算法未初始化')

            timestamp = self._current_ms()

            # Clock moved backwards
            if timestamp < self.last_timestamp:
                back_ms = self.last_timestamp - timestamp
                if back_ms > SnowflakeConfig.CLOCK_BACKWARD_TOLERANCE_MS:
                    raise errors.ServerError(msg=f'雪花 ID 生成失败,时钟回拨超过 {back_ms} ms,请立即联系系统管理员')
                log.warning(f'检测到时钟回拨 {back_ms} ms,等待恢复...')
                timestamp = self._till_next_ms(self.last_timestamp)

            # Same millisecond: bump the sequence, moving to the next millisecond on wrap-around
            if timestamp == self.last_timestamp:
                self.sequence = (self.sequence + 1) & SnowflakeConfig.SEQUENCE_MASK
                if self.sequence == 0:
                    timestamp = self._till_next_ms(self.last_timestamp)
            else:
                self.sequence = 0

            self.last_timestamp = timestamp

            # Assemble the ID: timestamp | datacenter | worker | sequence
            return (
                ((timestamp - SnowflakeConfig.EPOCH) << SnowflakeConfig.TIMESTAMP_LEFT_SHIFT)
                | (self.datacenter_id << SnowflakeConfig.DATACENTER_ID_SHIFT)
                | (self.worker_id << SnowflakeConfig.WORKER_ID_SHIFT)
                | self.sequence
            )

    @staticmethod
    def parse(snowflake_id: int) -> SnowflakeInfo:
        """
        Split a snowflake ID into its components

        :param snowflake_id: ID to decode
        :return: timestamp (ms), formatted datetime, datacenter ID, worker ID and sequence
        """
        timestamp = (snowflake_id >> SnowflakeConfig.TIMESTAMP_LEFT_SHIFT) + SnowflakeConfig.EPOCH
        datacenter_id = (snowflake_id >> SnowflakeConfig.DATACENTER_ID_SHIFT) & SnowflakeConfig.MAX_DATACENTER_ID
        worker_id = (snowflake_id >> SnowflakeConfig.WORKER_ID_SHIFT) & SnowflakeConfig.MAX_WORKER_ID
        sequence = snowflake_id & SnowflakeConfig.SEQUENCE_MASK

        return SnowflakeInfo(
            timestamp=timestamp,
            datetime=timezone.to_str(datetime.datetime.fromtimestamp(timestamp / 1000, timezone.tz_info)),
            datacenter_id=datacenter_id,
            worker_id=worker_id,
            sequence=sequence,
        )