Skip to content

Commit 49cc37d

Browse files
committed
Fabric: reserve externally provided MASTER_PORT values
1 parent a1250f7 commit 49cc37d

File tree

5 files changed

+242
-14
lines changed

5 files changed

+242
-14
lines changed

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,24 +104,38 @@ def teardown(self) -> None:
104104
if "WORLD_SIZE" in os.environ:
105105
del os.environ["WORLD_SIZE"]
106106

107+
if self._main_port != -1:
108+
get_port_manager().release_port(self._main_port)
109+
self._main_port = -1
110+
111+
os.environ.pop("MASTER_PORT", None)
112+
os.environ.pop("MASTER_ADDR", None)
113+
107114

108115
def find_free_network_port() -> int:
109116
"""Finds a free port on localhost.
110117
111118
It is useful in single-node training when we don't want to connect to a real main node but have to set the
112119
`MASTER_PORT` environment variable.
113120
114-
This function uses a global port manager to prevent internal race conditions within the test suite.
115121
The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released.
116122
117-
Note:
118-
While this prevents collisions between concurrent Lightning tests, external processes can still
119-
claim the port between allocation and binding. For production use, explicitly set the MASTER_PORT
120-
environment variable.
121-
122123
Returns:
123124
A port number that is reserved and free at the time of allocation
124125
125126
"""
127+
# If an external launcher already specified a MASTER_PORT (for example, torch.distributed.run / torchrun or
128+
# multiprocessing helpers), reserve it through the port manager so no other test reuses the same number.
129+
if "MASTER_PORT" in os.environ:
130+
master_port_str = os.environ["MASTER_PORT"]
131+
try:
132+
existing_port = int(master_port_str)
133+
except ValueError:
134+
pass
135+
else:
136+
port_manager = get_port_manager()
137+
if port_manager.reserve_existing_port(existing_port):
138+
return existing_port
139+
126140
port_manager = get_port_manager()
127141
return port_manager.allocate_port()

src/lightning/fabric/utilities/port_manager.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,26 @@ def release_all(self) -> None:
8989
with self._lock:
9090
self._allocated_ports.clear()
9191

92+
def reserve_existing_port(self, port: int) -> bool:
    """Reserve a port that was allocated externally.

    The port is only recorded in this manager's set of allocated ports; no socket is opened or
    probed here, so the caller remains responsible for the port actually being usable.

    Args:
        port: The externally assigned port to reserve.

    Returns:
        True if the port was reserved (or already reserved), False if the port value is invalid.

    """
    # Only 1-65535 can be reserved by number; 0 and negative values are rejected as invalid.
    if not 1 <= port <= 65535:
        return False

    with self._lock:
        # Adding to the allocated-port set is idempotent, so an already-reserved port needs no
        # separate membership check before the insert.
        self._allocated_ports.add(port)
    return True
111+
92112
@contextmanager
93113
def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
94114
"""Context manager for automatic port cleanup.

tests/tests_fabric/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,16 @@ def teardown_process_group():
9494
_destroy_dist_connection()
9595

9696
# Release the port from the manager so it can be reused
97+
manager = get_port_manager()
98+
9799
if port_to_release is not None:
98-
get_port_manager().release_port(port_to_release)
100+
manager.release_port(port_to_release)
101+
102+
# If a process group created or updated MASTER_PORT during the test, make sure we track it and then clear it.
103+
if "MASTER_PORT" in os.environ:
104+
with contextlib.suppress(ValueError):
105+
manager.reserve_existing_port(int(os.environ["MASTER_PORT"]))
106+
os.environ.pop("MASTER_PORT", None)
99107

100108

101109
@pytest.fixture(autouse=True)

tests/tests_fabric/plugins/environments/test_lightning.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818

1919
from lightning.fabric.plugins.environments import LightningEnvironment
20+
from lightning.fabric.utilities.port_manager import get_port_manager
2021

2122

2223
@mock.patch.dict(os.environ, {}, clear=True)
@@ -84,5 +85,18 @@ def test_teardown():
8485
assert "WORLD_SIZE" not in os.environ
8586

8687

88+
@mock.patch.dict(os.environ, {}, clear=True)
89+
def test_teardown_releases_port_and_env():
90+
env = LightningEnvironment()
91+
port = env.main_port
92+
assert port in get_port_manager()._allocated_ports
93+
94+
env.teardown()
95+
96+
assert port not in get_port_manager()._allocated_ports
97+
assert "MASTER_PORT" not in os.environ
98+
assert "MASTER_ADDR" not in os.environ
99+
100+
87101
def test_detect():
88102
assert LightningEnvironment.detect()

tests/tests_fabric/utilities/test_port_manager.py

Lines changed: 179 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,45 @@
1818
import threading
1919
from collections import Counter
2020

21+
import pytest
22+
2123
from lightning.fabric.plugins.environments.lightning import find_free_network_port
2224
from lightning.fabric.utilities.port_manager import PortManager, get_port_manager
2325

26+
# =============================================================================
27+
# Fixtures
28+
# =============================================================================
29+
30+
31+
@pytest.fixture
32+
def with_master_port():
33+
"""Fixture that sets MASTER_PORT before test runs, for conftest coverage."""
34+
port = find_free_network_port()
35+
previous_value = os.environ.get("MASTER_PORT")
36+
os.environ["MASTER_PORT"] = str(port)
37+
try:
38+
yield port
39+
finally:
40+
if previous_value is None:
41+
os.environ.pop("MASTER_PORT", None)
42+
else:
43+
os.environ["MASTER_PORT"] = previous_value
44+
45+
46+
@pytest.fixture
47+
def with_invalid_master_port():
48+
"""Fixture that sets invalid MASTER_PORT to test error handling."""
49+
previous_value = os.environ.get("MASTER_PORT")
50+
os.environ["MASTER_PORT"] = "not_a_valid_port_number"
51+
try:
52+
yield
53+
finally:
54+
if previous_value is None:
55+
os.environ.pop("MASTER_PORT", None)
56+
else:
57+
os.environ["MASTER_PORT"] = previous_value
58+
59+
2460
# =============================================================================
2561
# Unit Tests for PortManager
2662
# =============================================================================
@@ -135,12 +171,19 @@ def test_port_manager_allocation_failure():
135171
"""Test that PortManager raises error when unable to allocate after max attempts."""
136172
manager = PortManager()
137173

138-
# This is hard to test without actually exhausting ports, but we can test
139-
# the error path by mocking or just ensure the code path exists
140-
# For now, just verify that max_attempts parameter exists
141-
port = manager.allocate_port(max_attempts=1)
174+
# Pre-allocate a large number of ports to make it harder to find a free one
175+
# The allocation below is expected to succeed; exhaustion of max_attempts is tested separately
176+
allocated_ports = [manager.allocate_port() for _ in range(50)]
177+
178+
# Test that it can still allocate with enough attempts
179+
port = manager.allocate_port(max_attempts=100)
142180
assert port >= 1024
143181

182+
# Clean up
183+
for p in allocated_ports:
184+
manager.release_port(p)
185+
manager.release_port(port)
186+
144187

145188
def test_port_manager_prevents_reallocation():
146189
"""Test that a port won't be allocated twice until released."""
@@ -160,6 +203,10 @@ def test_port_manager_prevents_reallocation():
160203
manager.release_port(port1)
161204
assert port1 not in manager._allocated_ports
162205

206+
# Clean up
207+
for port in more_ports:
208+
manager.release_port(port)
209+
163210

164211
def test_get_port_manager_singleton():
165212
"""Test that get_port_manager returns the same instance."""
@@ -173,6 +220,9 @@ def test_get_port_manager_singleton():
173220
port = manager1.allocate_port()
174221
assert port in manager2._allocated_ports
175222

223+
# Clean up
224+
manager1.release_port(port)
225+
176226

177227
def test_get_port_manager_thread_safe_singleton():
178228
"""Test that get_port_manager creates singleton safely across threads."""
@@ -232,16 +282,21 @@ def test_port_manager_concurrent_allocation_and_release():
232282
manager = PortManager()
233283
ports = []
234284
lock = threading.Lock()
285+
active_ports: set[int] = set()
235286

236287
def allocate_and_release():
237288
for _ in range(5):
238289
# Allocate a port
239290
port = manager.allocate_port()
240291
with lock:
292+
assert port not in active_ports, "Port allocated concurrently"
293+
active_ports.add(port)
241294
ports.append(port)
242295

243296
# Release it immediately
244297
manager.release_port(port)
298+
with lock:
299+
active_ports.remove(port)
245300

246301
# Run multiple threads
247302
threads = [threading.Thread(target=allocate_and_release) for _ in range(10)]
@@ -254,9 +309,6 @@ def allocate_and_release():
254309
# Should have allocated 50 ports total (10 threads × 5 ports)
255310
assert len(ports) == 50
256311

257-
# All should be unique (no port allocated twice before being released)
258-
assert len(set(ports)) == 50, "Same port was allocated to multiple threads before release"
259-
260312
# After all releases, manager should have no ports allocated
261313
assert len(manager._allocated_ports) == 0
262314

@@ -280,6 +332,42 @@ def test_port_manager_atexit_cleanup():
280332
assert len(manager._allocated_ports) == 0
281333

282334

335+
def test_port_manager_reserve_existing_port_free():
336+
"""reserve_existing_port should succeed for free ports and track them."""
337+
manager = PortManager()
338+
339+
port = manager._find_free_port()
340+
assert manager.reserve_existing_port(port)
341+
assert port in manager._allocated_ports
342+
343+
# Second call should succeed but not duplicate
344+
assert manager.reserve_existing_port(port)
345+
assert len(manager._allocated_ports) == 1
346+
347+
348+
def test_port_manager_reserve_existing_port_invalid_value():
349+
"""reserve_existing_port should reject invalid port numbers."""
350+
manager = PortManager()
351+
352+
assert not manager.reserve_existing_port(0)
353+
assert not manager.reserve_existing_port(-1)
354+
assert not manager.reserve_existing_port(70000)
355+
356+
357+
def test_port_manager_reserve_existing_port_after_release():
    """Ports released from sockets should become reservable."""
    manager = PortManager()

    # Bind to port 0 so the OS picks a free port. The ``with`` block closes the socket even if
    # ``setsockopt`` or ``bind`` raises, so the port is no longer held when it is reserved below.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        reusable_port = s.getsockname()[1]

    assert manager.reserve_existing_port(reusable_port)
    assert reusable_port in manager._allocated_ports
369+
370+
283371
def test_port_manager_context_manager():
284372
"""Test that context manager automatically releases ports."""
285373
manager = PortManager()
@@ -430,6 +518,44 @@ def test_port_allocation_simulates_distributed_test_lifecycle():
430518
assert len(manager._allocated_ports) == initial_count
431519

432520

521+
def test_conftest_cleanup_with_master_port_set(with_master_port):
522+
"""Test conftest cleanup when MASTER_PORT is set before test starts.
523+
524+
This test uses a fixture to set MASTER_PORT before the test runs, allowing the conftest teardown_process_group
525+
fixture to capture and clean it up. This ensures the conftest cleanup code is covered.
526+
527+
"""
528+
manager = get_port_manager()
529+
port = with_master_port # Port was set by fixture
530+
531+
# Verify port is allocated
532+
assert port in manager._allocated_ports
533+
assert os.environ.get("MASTER_PORT") == str(port)
534+
535+
# Leave MASTER_PORT set - conftest teardown will clean it up
536+
# After this test, teardown_process_group will:
537+
# 1. Detect MASTER_PORT in os.environ (line captured before yield)
538+
# 2. Call get_port_manager().release_port(port)
539+
# 3. Port gets released back to manager
540+
541+
542+
def test_conftest_handles_invalid_master_port(with_invalid_master_port):
543+
"""Test conftest handles invalid MASTER_PORT gracefully.
544+
545+
This exercises the contextlib.suppress(ValueError) path in the conftest teardown_process_group fixture.
546+
547+
"""
548+
# Fixture set MASTER_PORT to "not_a_valid_port_number"
549+
# The conftest will try to parse it: int(os.environ["MASTER_PORT"])
550+
# This will raise ValueError, which should be caught by contextlib.suppress
551+
552+
# Verify the invalid value is set
553+
assert os.environ.get("MASTER_PORT") == "not_a_valid_port_number"
554+
555+
# This test just needs to complete without crashing
556+
# The conftest teardown will handle the ValueError gracefully
557+
558+
433559
def test_multiple_tests_can_reuse_ports_after_release():
434560
"""Test that ports can be reused after being released."""
435561
manager = get_port_manager()
@@ -521,3 +647,49 @@ def test_port_manager_survives_multiple_test_sessions():
521647
# Clean up
522648
for port in session2_ports + session3_ports:
523649
manager.release_port(port)
650+
651+
652+
def test_port_manager_allocation_runtime_error():
653+
"""Test that allocation fails gracefully when max_attempts is exhausted."""
654+
manager = PortManager()
655+
656+
# Mock the _find_free_port to always return a port that's already allocated
657+
# This will cause max_attempts to be exhausted
658+
allocated_port = manager.allocate_port()
659+
660+
# Save original method
661+
original_find = manager._find_free_port
662+
663+
# Make _find_free_port always return the already-allocated port
664+
def always_return_allocated():
665+
return allocated_port
666+
667+
manager._find_free_port = always_return_allocated
668+
669+
# This should raise RuntimeError after max_attempts
670+
with pytest.raises(RuntimeError, match="Failed to allocate a free port after .* attempts"):
671+
manager.allocate_port(max_attempts=5)
672+
673+
# Restore original method and clean up
674+
manager._find_free_port = original_find
675+
manager.release_port(allocated_port)
676+
677+
678+
def test_find_free_network_port_respects_existing_master_port(with_master_port):
679+
"""find_free_network_port should reuse externally provided MASTER_PORT."""
680+
manager = get_port_manager()
681+
port = with_master_port
682+
683+
returned_port = find_free_network_port()
684+
assert returned_port == port
685+
assert port in manager._allocated_ports
686+
687+
688+
def test_find_free_network_port_handles_invalid_master_port(with_invalid_master_port):
    """Invalid MASTER_PORT values should fall back to allocating a fresh port."""
    manager = get_port_manager()

    returned_port = find_free_network_port()
    try:
        assert isinstance(returned_port, int)
        # Comparing the int result against the invalid string is always true, so check that the
        # fallback produced a usable port number instead.
        assert 1 <= returned_port <= 65535
        assert returned_port in manager._allocated_ports
    finally:
        # The port is tracked by the process-wide manager; release it so it does not stay
        # reserved for the rest of the test session.
        manager.release_port(returned_port)

0 commit comments

Comments
 (0)