Skip to content

Commit 2e1d277

Browse files
committed
Fabric: reserve externally provided MASTER_PORT values
1 parent a1250f7 commit 2e1d277

File tree

6 files changed

+356
-40
lines changed

6 files changed

+356
-40
lines changed

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,24 +104,38 @@ def teardown(self) -> None:
104104
if "WORLD_SIZE" in os.environ:
105105
del os.environ["WORLD_SIZE"]
106106

107+
if self._main_port != -1:
108+
get_port_manager().release_port(self._main_port)
109+
self._main_port = -1
110+
111+
os.environ.pop("MASTER_PORT", None)
112+
os.environ.pop("MASTER_ADDR", None)
113+
107114

108115
def find_free_network_port() -> int:
109116
"""Finds a free port on localhost.
110117
111118
It is useful in single-node training when we don't want to connect to a real main node but have to set the
112119
`MASTER_PORT` environment variable.
113120
114-
This function uses a global port manager to prevent internal race conditions within the test suite.
115121
The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released.
116122
117-
Note:
118-
While this prevents collisions between concurrent Lightning tests, external processes can still
119-
claim the port between allocation and binding. For production use, explicitly set the MASTER_PORT
120-
environment variable.
121-
122123
Returns:
123124
A port number that is reserved and free at the time of allocation
124125
125126
"""
127+
# If an external launcher already specified a MASTER_PORT (for example, torch.distributed.spawn or
128+
# multiprocessing helpers), reserve it through the port manager so no other test reuses the same number.
129+
if "MASTER_PORT" in os.environ:
130+
master_port_str = os.environ["MASTER_PORT"]
131+
try:
132+
existing_port = int(master_port_str)
133+
except ValueError:
134+
pass
135+
else:
136+
port_manager = get_port_manager()
137+
if port_manager.reserve_existing_port(existing_port):
138+
return existing_port
139+
126140
port_manager = get_port_manager()
127141
return port_manager.allocate_port()

src/lightning/fabric/utilities/port_manager.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616
import atexit
1717
import socket
1818
import threading
19+
from collections import deque
1920
from collections.abc import Iterator
2021
from contextlib import contextmanager
2122
from typing import Optional
2223

24+
# Maximum number of recently released ports to track before reuse
25+
_RECENTLY_RELEASED_PORTS_MAXLEN = 256
26+
2327

2428
class PortManager:
2529
"""Thread-safe port manager to prevent EADDRINUSE errors.
@@ -33,10 +37,14 @@ class PortManager:
3337
def __init__(self) -> None:
3438
self._lock = threading.Lock()
3539
self._allocated_ports: set[int] = set()
40+
# Recently released ports are kept in a queue to avoid immediate reuse
41+
self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)
42+
# Counter to vary starting position on each allocation
43+
self._allocation_counter = 0
3644
# Register cleanup to release all ports on exit
3745
atexit.register(self.release_all)
3846

39-
def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 100) -> int:
47+
def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 1000) -> int:
4048
"""Allocate a free port, ensuring it's not already reserved.
4149
4250
Args:
@@ -55,23 +63,42 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
5563
if (
5664
preferred_port is not None
5765
and preferred_port not in self._allocated_ports
66+
and preferred_port not in self._recently_released
5867
and self._is_port_free(preferred_port)
5968
):
6069
self._allocated_ports.add(preferred_port)
6170
return preferred_port
6271

63-
# Try to find a free port
72+
# Ephemeral port range (49152-65535)
73+
# We'll search through this range to find a port that's:
74+
# 1. Not in our allocated set
75+
# 2. Not in our recently_released queue
76+
# 3. Actually free according to the OS
77+
78+
# Use combination of thread ID and allocation counter to vary starting point
79+
# This distributes allocation across the port range and avoids hotspots
80+
port_range_size = 65535 - 49152 + 1
81+
thread_id = threading.get_ident()
82+
self._allocation_counter += 1
83+
start_offset = (hash(thread_id) + self._allocation_counter) % port_range_size
84+
6485
for attempt in range(max_attempts):
65-
port = self._find_free_port()
86+
# Cycle through port range with varied offset
87+
port = 49152 + ((start_offset + attempt) % port_range_size)
6688

67-
# Double-check it's not in our reserved set (shouldn't happen, but be safe)
68-
if port not in self._allocated_ports:
89+
# Skip if in our tracking structures
90+
if port in self._allocated_ports or port in self._recently_released:
91+
continue
92+
93+
# Check if actually free
94+
if self._is_port_free(port):
6995
self._allocated_ports.add(port)
7096
return port
7197

7298
raise RuntimeError(
7399
f"Failed to allocate a free port after {max_attempts} attempts. "
74-
f"Currently allocated ports: {len(self._allocated_ports)}"
100+
f"Currently allocated: {len(self._allocated_ports)}, "
101+
f"recently released: {len(self._recently_released)}"
75102
)
76103

77104
def release_port(self, port: int) -> None:
@@ -82,12 +109,43 @@ def release_port(self, port: int) -> None:
82109
83110
"""
84111
with self._lock:
85-
self._allocated_ports.discard(port)
112+
if port in self._allocated_ports:
113+
self._allocated_ports.remove(port)
114+
# Add to the back of the queue; oldest will be evicted when queue is full
115+
self._recently_released.append(port)
86116

87117
def release_all(self) -> None:
88118
"""Release all allocated ports."""
89119
with self._lock:
90120
self._allocated_ports.clear()
121+
self._recently_released.clear()
122+
123+
def reserve_existing_port(self, port: int) -> bool:
124+
"""Reserve a port that was allocated externally.
125+
126+
Args:
127+
port: The externally assigned port to reserve.
128+
129+
Returns:
130+
True if the port was reserved (or already reserved), False if the port value is invalid.
131+
132+
"""
133+
if port <= 0 or port > 65535:
134+
return False
135+
136+
with self._lock:
137+
if port in self._allocated_ports:
138+
return True
139+
140+
# Remove from recently released queue if present (we're explicitly reserving it)
141+
if port in self._recently_released:
142+
# Create a new deque without this port
143+
self._recently_released = deque(
144+
(p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN
145+
)
146+
147+
self._allocated_ports.add(port)
148+
return True
91149

92150
@contextmanager
93151
def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
@@ -121,7 +179,8 @@ def _find_free_port() -> int:
121179
122180
"""
123181
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
124-
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
182+
# Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
183+
# which binds without it, so ports in TIME_WAIT will be rejected
125184
s.bind(("", 0))
126185
port = s.getsockname()[1]
127186
s.close()
@@ -140,7 +199,8 @@ def _is_port_free(port: int) -> bool:
140199
"""
141200
try:
142201
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
143-
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
202+
# Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
203+
# which binds without it, so ports in TIME_WAIT will be rejected
144204
s.bind(("", port))
145205
s.close()
146206
return True

tests/tests_fabric/conftest.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,20 @@ def teardown_process_group():
8282

8383
from lightning.fabric.utilities.port_manager import get_port_manager
8484

85-
# Record the port used in this test (if any)
86-
port_to_release = None
87-
if "MASTER_PORT" in os.environ:
88-
with contextlib.suppress(ValueError, KeyError):
89-
port_to_release = int(os.environ["MASTER_PORT"])
90-
9185
yield
9286

9387
# Clean up distributed connection
9488
_destroy_dist_connection()
9589

96-
# Release the port from the manager so it can be reused
97-
if port_to_release is not None:
98-
get_port_manager().release_port(port_to_release)
90+
manager = get_port_manager()
91+
92+
# If a process group created or updated MASTER_PORT during the test, reserve it and then clear it
93+
if "MASTER_PORT" in os.environ:
94+
with contextlib.suppress(ValueError):
95+
port = int(os.environ["MASTER_PORT"])
96+
manager.reserve_existing_port(port)
97+
manager.release_port(port)
98+
os.environ.pop("MASTER_PORT", None)
9999

100100

101101
@pytest.fixture(autouse=True)

tests/tests_fabric/plugins/environments/test_lightning.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818

1919
from lightning.fabric.plugins.environments import LightningEnvironment
20+
from lightning.fabric.utilities.port_manager import get_port_manager
2021

2122

2223
@mock.patch.dict(os.environ, {}, clear=True)
@@ -84,5 +85,18 @@ def test_teardown():
8485
assert "WORLD_SIZE" not in os.environ
8586

8687

88+
@mock.patch.dict(os.environ, {}, clear=True)
89+
def test_teardown_releases_port_and_env():
90+
env = LightningEnvironment()
91+
port = env.main_port
92+
assert port in get_port_manager()._allocated_ports
93+
94+
env.teardown()
95+
96+
assert port not in get_port_manager()._allocated_ports
97+
assert "MASTER_PORT" not in os.environ
98+
assert "MASTER_ADDR" not in os.environ
99+
100+
87101
def test_detect():
88102
assert LightningEnvironment.detect()

0 commit comments

Comments
 (0)