
Commit 16bcaf9

Add distributed deployment support for snowflake ID (#927)
* feat: Add support for distributed deployment of Snowflake
* Update the algorithm implementation
* Remove duplicate code and update error messages
1 parent 551dc51 commit 16bcaf9

4 files changed: +196 -70 lines changed

backend/common/dataclasses.py

Lines changed: 2 additions & 2 deletions
@@ -70,6 +70,6 @@ class UploadUrl:
 class SnowflakeInfo:
     timestamp: int
     datetime: str
-    cluster_id: int
-    node_id: int
+    datacenter_id: int
+    worker_id: int
     sequence: int
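
For reference, the SnowflakeInfo dataclass after this change would read as follows. This is only a sketch: the decorator and anything else around these fields is outside the hunk shown above and is assumed here.

from dataclasses import dataclass


@dataclass  # assumption: the actual decorator is not visible in this hunk
class SnowflakeInfo:
    timestamp: int      # millisecond timestamp encoded in the ID
    datetime: str       # human-readable datetime string
    datacenter_id: int  # renamed from cluster_id
    worker_id: int      # renamed from node_id
    sequence: int       # per-millisecond sequence number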

backend/core/conf.py

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,15 @@ class Settings(BaseSettings):
     # Redis
     REDIS_TIMEOUT: int = 5
 
+    # .env Snowflake
+    SNOWFLAKE_DATACENTER_ID: int | None = None
+    SNOWFLAKE_WORKER_ID: int | None = None
+
+    # Snowflake
+    SNOWFLAKE_REDIS_PREFIX: str = 'fba:snowflake'
+    SNOWFLAKE_HEARTBEAT_INTERVAL_SECONDS: int = 30
+    SNOWFLAKE_NODE_TTL_SECONDS: int = 60
+
     # .env Token
     TOKEN_SECRET_KEY: str  # secret key: secrets.token_urlsafe(32)
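
The two SNOWFLAKE_*_ID settings are read from .env and are optional: when both are set, the node identity is fixed; when both are left unset, the Snowflake class falls back to Redis-based dynamic allocation (see backend/utils/snowflake.py below). A minimal .env sketch for a fixed node, with illustrative values only:

# Each instance must use a unique (datacenter_id, worker_id) pair; 0/0 is just an example
SNOWFLAKE_DATACENTER_ID=0
SNOWFLAKE_WORKER_ID=0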

backend/core/registrar.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,7 @@
 from backend.utils.health_check import ensure_unique_route_names, http_limit_callback
 from backend.utils.openapi import simplify_operation_ids
 from backend.utils.serializers import MsgSpecJSONResponse
+from backend.utils.snowflake import snowflake
 
 
 @asynccontextmanager
@@ -57,11 +58,17 @@ async def register_init(app: FastAPI) -> AsyncGenerator[None, None]:
         http_callback=http_limit_callback,
     )
 
+    # Initialize the snowflake node
+    await snowflake.init()
+
     # Create the operation log task
     create_task(OperaLogMiddleware.consumer())
 
     yield
 
+    # Release the snowflake node
+    await snowflake.shutdown()
+
     # Close the redis connection
     await redis_client.aclose()
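
The pattern wired in here can be sketched in isolation with a bare FastAPI lifespan; this illustrates only the init/shutdown ordering and is not the project's actual register_init implementation:

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

from backend.utils.snowflake import snowflake


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    # Acquire the node identity (from .env or Redis) before serving requests
    await snowflake.init()
    yield
    # Release the Redis-allocated node on shutdown so the ID pair can be reused
    await snowflake.shutdown()


app = FastAPI(lifespan=lifespan)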

backend/utils/snowflake.py

Lines changed: 178 additions & 68 deletions
@@ -1,15 +1,22 @@
+import asyncio
+import datetime
+import os
+import threading
 import time
 
 from dataclasses import dataclass
 
 from backend.common.dataclasses import SnowflakeInfo
 from backend.common.exception import errors
+from backend.common.log import log
 from backend.core.conf import settings
+from backend.database.redis import redis_client
+from backend.utils.timezone import timezone
 
 
 @dataclass(frozen=True)
 class SnowflakeConfig:
-    """Snowflake algorithm configuration class"""
+    """Snowflake algorithm configuration class, using Twitter's original 64-bit Snowflake ID bit allocation (the common standard)"""
 
     # Bit allocation
     WORKER_ID_BITS: int = 5
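
The constants referenced throughout the rest of this file (MAX_DATACENTER_ID, MAX_WORKER_ID, the shift values and SEQUENCE_MASK) are defined in an unchanged part of SnowflakeConfig that this diff does not show. Assuming the standard Twitter bit allocation named in the docstring (5 worker bits, 5 datacenter bits, 12 sequence bits), they work out roughly as in this sketch:

# Assumed standard Snowflake derivations; the repository may define them as literals instead
WORKER_ID_BITS = 5
DATACENTER_ID_BITS = 5
SEQUENCE_BITS = 12

MAX_WORKER_ID = -1 ^ (-1 << WORKER_ID_BITS)          # 31
MAX_DATACENTER_ID = -1 ^ (-1 << DATACENTER_ID_BITS)  # 31
SEQUENCE_MASK = -1 ^ (-1 << SEQUENCE_BITS)           # 4095

WORKER_ID_SHIFT = SEQUENCE_BITS                                             # 12
DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS                        # 17
TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS  # 22
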
@@ -29,97 +36,200 @@ class SnowflakeConfig:
     # Epoch timestamp
     EPOCH: int = 1262275200000
 
-    # Default values
-    DEFAULT_DATACENTER_ID: int = 1
-    DEFAULT_WORKER_ID: int = 0
-    DEFAULT_SEQUENCE: int = 0
+    # Clock rollback tolerance threshold, to absorb normal backward jumps caused by NTP auto-sync (non-standard)
+    CLOCK_BACKWARD_TOLERANCE_MS: int = 10_000
+
+
+class SnowflakeNodeManager:
+    """Snowflake node manager, responsible for allocating and managing node IDs via Redis"""
+
+    def __init__(self) -> None:
+        """Initialize the node manager"""
+        self.datacenter_id: int | None = None
+        self.worker_id: int | None = None
+        self.node_redis_prefix: str = f'{settings.SNOWFLAKE_REDIS_PREFIX}:nodes'
+        self._heartbeat_task: asyncio.Task | None = None
+
+    async def acquire_node_id(self) -> tuple[int, int]:
+        """Acquire an available datacenter_id and worker_id from Redis"""
+        occupied_nodes = set()
+        async for key in redis_client.scan_iter(match=f'{self.node_redis_prefix}:*'):
+            parts = key.split(':')
+            if len(parts) >= 5:
+                try:
+                    datacenter_id = int(parts[-2])
+                    worker_id = int(parts[-1])
+                    occupied_nodes.add((datacenter_id, worker_id))
+                except ValueError:
+                    continue
+
+        # Sequentially search for the first available ID combination
+        for datacenter_id in range(SnowflakeConfig.MAX_DATACENTER_ID + 1):
+            for worker_id in range(SnowflakeConfig.MAX_WORKER_ID + 1):
+                if (datacenter_id, worker_id) not in occupied_nodes and await self._register(datacenter_id, worker_id):
+                    return datacenter_id, worker_id
+
+        raise errors.ServerError(msg='No available snowflake node, the node pool is exhausted')
+
+    async def _register(self, datacenter_id: int, worker_id: int) -> bool:
+        key = f'{self.node_redis_prefix}:{datacenter_id}:{worker_id}'
+        value = f'pid:{os.getpid()}-ts:{timezone.now().timestamp()}'
+        return await redis_client.set(key, value, nx=True, ex=settings.SNOWFLAKE_NODE_TTL_SECONDS)
+
+    async def start_heartbeat(self, datacenter_id: int, worker_id: int) -> None:
+        """Start the node heartbeat"""
+        self.datacenter_id = datacenter_id
+        self.worker_id = worker_id
+
+        async def heartbeat() -> None:
+            key = f'{self.node_redis_prefix}:{datacenter_id}:{worker_id}'
+            while True:
+                await asyncio.sleep(settings.SNOWFLAKE_HEARTBEAT_INTERVAL_SECONDS)
+                try:
+                    await redis_client.expire(key, settings.SNOWFLAKE_NODE_TTL_SECONDS)
+                    log.debug(f'Snowflake node heartbeat task started: datacenter_id={datacenter_id}, worker_id={worker_id}')
+                except Exception as e:
+                    log.error(f'Snowflake node heartbeat task failed: {e}')
+
+        self._heartbeat_task = asyncio.create_task(heartbeat())
+
+    async def release(self) -> None:
+        """Release the node"""
+        if self._heartbeat_task:
+            self._heartbeat_task.cancel()
+            try:
+                await self._heartbeat_task
+            except asyncio.CancelledError:
+                log.debug(f'Snowflake node heartbeat task released: datacenter_id={self.datacenter_id}, worker_id={self.worker_id}')
+
+        if self.datacenter_id is not None and self.worker_id is not None:
+            key = f'{self.node_redis_prefix}:{self.datacenter_id}:{self.worker_id}'
+            await redis_client.delete(key)
 
 
 class Snowflake:
     """Snowflake algorithm class"""
 
-    def __init__(
-        self,
-        cluster_id: int = SnowflakeConfig.DEFAULT_DATACENTER_ID,
-        node_id: int = SnowflakeConfig.DEFAULT_WORKER_ID,
-        sequence: int = SnowflakeConfig.DEFAULT_SEQUENCE,
-    ) -> None:
-        """
-        Initialize the snowflake ID generator
-
-        :param cluster_id: cluster ID (0-31)
-        :param node_id: node ID (0-31)
-        :param sequence: starting sequence number
-        """
-        if cluster_id < 0 or cluster_id > SnowflakeConfig.MAX_DATACENTER_ID:
-            raise errors.RequestError(msg=f'Cluster ID must be between 0 and {SnowflakeConfig.MAX_DATACENTER_ID}')
-        if node_id < 0 or node_id > SnowflakeConfig.MAX_WORKER_ID:
-            raise errors.RequestError(msg=f'Node ID must be between 0 and {SnowflakeConfig.MAX_WORKER_ID}')
-
-        self.node_id = node_id
-        self.cluster_id = cluster_id
-        self.sequence = sequence
-        self.last_timestamp = -1
+    def __init__(self) -> None:
+        """Initialize the snowflake algorithm"""
+        self.datacenter_id: int | None = None
+        self.worker_id: int | None = None
+        self.sequence: int = 0
+        self.last_timestamp: int = -1
+
+        self._lock = threading.Lock()
+        self._initialized = False
+        self._node_manager: SnowflakeNodeManager | None = None
+        self._auto_allocated = False  # Marks whether the node ID was auto-allocated via Redis
+
+    async def init(self) -> None:
+        """Initialize the snowflake algorithm"""
+        if self._initialized:
+            return
+
+        with self._lock:
+            # Fixed allocation via environment variables
+            if settings.SNOWFLAKE_DATACENTER_ID is not None and settings.SNOWFLAKE_WORKER_ID is not None:
+                self.datacenter_id = settings.SNOWFLAKE_DATACENTER_ID
+                self.worker_id = settings.SNOWFLAKE_WORKER_ID
+                log.debug(
+                    f'Snowflake is using a fixed node from environment variables: datacenter_id={self.datacenter_id}, worker_id={self.worker_id}'
+                )
+            elif (settings.SNOWFLAKE_DATACENTER_ID is not None and settings.SNOWFLAKE_WORKER_ID is None) or (
+                settings.SNOWFLAKE_DATACENTER_ID is None and settings.SNOWFLAKE_WORKER_ID is not None
+            ):
+                log.error('Snowflake datacenter_id and worker_id are misconfigured; they must both be non-None or both be None')
+                raise errors.ServerError(msg='Snowflake configuration failed, please contact the system administrator')
+            else:
+                # Dynamic allocation via Redis
+                self._node_manager = SnowflakeNodeManager()
+                self.datacenter_id, self.worker_id = await self._node_manager.acquire_node_id()
+                self._auto_allocated = True
+                await self._node_manager.start_heartbeat(self.datacenter_id, self.worker_id)
+                log.debug(
+                    f'Snowflake is using a Redis-allocated node: datacenter_id={self.datacenter_id}, worker_id={self.worker_id}'
+                )
+
+            # Strict range validation
+            if not (0 <= self.datacenter_id <= SnowflakeConfig.MAX_DATACENTER_ID):
+                log.error(f'Snowflake datacenter_id configuration failed, it must be between 0 and {SnowflakeConfig.MAX_DATACENTER_ID}')
+                raise errors.ServerError(msg='Snowflake datacenter configuration failed, please contact the system administrator')
+            if not (0 <= self.worker_id <= SnowflakeConfig.MAX_WORKER_ID):
+                log.error(f'Snowflake worker_id configuration failed, it must be between 0 and {SnowflakeConfig.MAX_WORKER_ID}')
+                raise errors.ServerError(msg='Snowflake worker configuration failed, please contact the system administrator')
+
+            self._initialized = True
+
+    async def shutdown(self) -> None:
+        """Release the Redis node"""
+        if self._node_manager and self._auto_allocated:
+            await self._node_manager.release()
 
     @staticmethod
-    def _current_millis() -> int:
-        """Return the current millisecond timestamp"""
-        return int(time.time() * 1000)
+    def _current_ms() -> int:
+        return int(timezone.now().timestamp() * 1000)
 
-    def _next_millis(self, last_timestamp: int) -> int:
-        """
-        Wait until the next millisecond
-
-        :param last_timestamp: timestamp of the last generated ID
-        :return:
-        """
-        timestamp = self._current_millis()
-        while timestamp <= last_timestamp:
-            time.sleep((last_timestamp - timestamp + 1) / 1000.0)
-            timestamp = self._current_millis()
-        return timestamp
+    def _till_next_ms(self, last_timestamp: int) -> int:
+        """Wait until the next millisecond"""
+        ts = self._current_ms()
+        while ts <= last_timestamp:
+            time.sleep(0.0001)
+            ts = self._current_ms()
+        return ts
 
     def generate(self) -> int:
         """Generate a snowflake ID"""
-        timestamp = self._current_millis()
-
-        if timestamp < self.last_timestamp:
-            raise errors.ServerError(msg=f'System clock moved backwards, refusing to generate IDs until {self.last_timestamp}')
-
-        if timestamp == self.last_timestamp:
-            self.sequence = (self.sequence + 1) & SnowflakeConfig.SEQUENCE_MASK
-            if self.sequence == 0:
-                timestamp = self._next_millis(self.last_timestamp)
-        else:
-            self.sequence = 0
-
-        self.last_timestamp = timestamp
-
-        return (
-            ((timestamp - SnowflakeConfig.EPOCH) << SnowflakeConfig.TIMESTAMP_LEFT_SHIFT)
-            | (self.cluster_id << SnowflakeConfig.DATACENTER_ID_SHIFT)
-            | (self.node_id << SnowflakeConfig.WORKER_ID_SHIFT)
-            | self.sequence
-        )
+        if not self._initialized:
+            raise errors.ServerError(msg='Snowflake ID generation failed, the snowflake algorithm has not been initialized')
+
+        with self._lock:
+            timestamp = self._current_ms()
+
+            # Clock rollback handling
+            if timestamp < self.last_timestamp:
+                back_ms = self.last_timestamp - timestamp
+                if back_ms <= SnowflakeConfig.CLOCK_BACKWARD_TOLERANCE_MS:
+                    log.warning(f'Clock rollback of {back_ms} ms detected, waiting for recovery...')
+                    timestamp = self._till_next_ms(self.last_timestamp)
+                else:
+                    raise errors.ServerError(msg=f'Snowflake ID generation failed, clock rollback exceeded {back_ms} ms, please contact the system administrator immediately')
+
+            # Sequence increments within the same millisecond
+            if timestamp == self.last_timestamp:
+                self.sequence = (self.sequence + 1) & SnowflakeConfig.SEQUENCE_MASK
+                if self.sequence == 0:
+                    timestamp = self._till_next_ms(self.last_timestamp)
+            else:
+                self.sequence = 0
+
+            self.last_timestamp = timestamp
+
+            # Assemble the 64-bit ID
+            return (
+                ((timestamp - SnowflakeConfig.EPOCH) << SnowflakeConfig.TIMESTAMP_LEFT_SHIFT)
+                | (self.datacenter_id << SnowflakeConfig.DATACENTER_ID_SHIFT)
+                | (self.worker_id << SnowflakeConfig.WORKER_ID_SHIFT)
+                | self.sequence
+            )
 
     @staticmethod
-    def parse_id(snowflake_id: int) -> SnowflakeInfo:
+    def parse(snowflake_id: int) -> SnowflakeInfo:
         """
         Parse a snowflake ID and extract the details it contains
 
         :param snowflake_id: snowflake ID
         :return:
         """
         timestamp = (snowflake_id >> SnowflakeConfig.TIMESTAMP_LEFT_SHIFT) + SnowflakeConfig.EPOCH
-        cluster_id = (snowflake_id >> SnowflakeConfig.DATACENTER_ID_SHIFT) & SnowflakeConfig.MAX_DATACENTER_ID
-        node_id = (snowflake_id >> SnowflakeConfig.WORKER_ID_SHIFT) & SnowflakeConfig.MAX_WORKER_ID
+        datacenter_id = (snowflake_id >> SnowflakeConfig.DATACENTER_ID_SHIFT) & SnowflakeConfig.MAX_DATACENTER_ID
+        worker_id = (snowflake_id >> SnowflakeConfig.WORKER_ID_SHIFT) & SnowflakeConfig.MAX_WORKER_ID
         sequence = snowflake_id & SnowflakeConfig.SEQUENCE_MASK
 
         return SnowflakeInfo(
             timestamp=timestamp,
-            datetime=time.strftime(settings.DATETIME_FORMAT, time.localtime(timestamp / 1000)),
-            cluster_id=cluster_id,
-            node_id=node_id,
+            datetime=timezone.to_str(datetime.datetime.fromtimestamp(timestamp / 1000, timezone.tz_info)),
+            datacenter_id=datacenter_id,
+            worker_id=worker_id,
             sequence=sequence,
         )
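
Taken together, typical usage once register_init has called snowflake.init() would look roughly like the sketch below; generate() is synchronous after initialization, and parse() recovers the fields carried by the renamed SnowflakeInfo dataclass. This is an illustration, not code from this commit:

from backend.utils.snowflake import Snowflake, snowflake

# e.g. inside a request handler, after the lifespan has run `await snowflake.init()`
new_id = snowflake.generate()

info = Snowflake.parse(new_id)
print(info.datetime, info.datacenter_id, info.worker_id, info.sequence)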

0 commit comments
