Skip to content

Commit 69819dc

Browse files
committed
Fabric: reserve externally provided MASTER_PORT values
1 parent a1250f7 commit 69819dc

File tree

6 files changed

+335
-17
lines changed

6 files changed

+335
-17
lines changed

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,24 +104,38 @@ def teardown(self) -> None:
104104
if "WORLD_SIZE" in os.environ:
105105
del os.environ["WORLD_SIZE"]
106106

107+
if self._main_port != -1:
108+
get_port_manager().release_port(self._main_port)
109+
self._main_port = -1
110+
111+
os.environ.pop("MASTER_PORT", None)
112+
os.environ.pop("MASTER_ADDR", None)
113+
107114

108115
def find_free_network_port() -> int:
    """Return a localhost port that is registered with the global port manager.

    Intended for single-node training, where no real main node exists but the `MASTER_PORT` environment variable
    still has to be populated.

    If `MASTER_PORT` is already present in the environment and holds an integer that the port manager accepts,
    that port is recorded with the manager and returned as-is. Otherwise (variable missing, not an integer, or
    rejected by the manager) a fresh port is requested from the manager.

    A freshly allocated port stays reserved, and is not handed out by later allocations, until it is explicitly
    released.

    Returns:
        A port number tracked by the port manager; when freshly allocated it was free at the time of allocation

    """
    env_value = os.environ.get("MASTER_PORT")
    if env_value is not None:
        # An outside launcher (for example a multiprocessing spawn helper) may already have chosen the port.
        # Register that number with the port manager so it is not handed out to anyone else.
        try:
            external_port = int(env_value)
        except ValueError:
            # Not a usable port number: fall through to a regular allocation.
            external_port = None
        if external_port is not None and get_port_manager().reserve_existing_port(external_port):
            return external_port

    return get_port_manager().allocate_port()

src/lightning/fabric/utilities/port_manager.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616
import atexit
1717
import socket
1818
import threading
19+
from collections import deque
1920
from collections.abc import Iterator
2021
from contextlib import contextmanager
2122
from typing import Optional
2223

24+
# Size of the recently-released history: a freed port is skipped by `allocate_port` until it is evicted from this window
25+
_RECENTLY_RELEASED_PORTS_MAXLEN = 256
26+
2327

2428
class PortManager:
2529
"""Thread-safe port manager to prevent EADDRINUSE errors.
@@ -33,6 +37,8 @@ class PortManager:
3337
def __init__(self) -> None:
    """Create empty bookkeeping structures and schedule cleanup for interpreter exit."""
    # Bounded history of freed ports; the oldest entry drops out once the queue is full, and ports still
    # listed here are passed over by allocation so they are not reused immediately.
    self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)
    # Ports currently handed out by (or registered with) this manager.
    self._allocated_ports: set[int] = set()
    # Guards every read and write of the two collections above.
    self._lock = threading.Lock()
    # Make sure all tracked ports are dropped when the process exits.
    atexit.register(self.release_all)
3844

@@ -55,6 +61,7 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
5561
if (
5662
preferred_port is not None
5763
and preferred_port not in self._allocated_ports
64+
and preferred_port not in self._recently_released
5865
and self._is_port_free(preferred_port)
5966
):
6067
self._allocated_ports.add(preferred_port)
@@ -64,7 +71,10 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
6471
for attempt in range(max_attempts):
6572
port = self._find_free_port()
6673

67-
# Double-check it's not in our reserved set (shouldn't happen, but be safe)
74+
# Skip ports that were recently released to avoid TIME_WAIT conflicts
75+
if port in self._recently_released:
76+
continue
77+
6878
if port not in self._allocated_ports:
6979
self._allocated_ports.add(port)
7080
return port
@@ -82,12 +92,43 @@ def release_port(self, port: int) -> None:
8292
8393
"""
8494
with self._lock:
85-
self._allocated_ports.discard(port)
95+
if port in self._allocated_ports:
96+
self._allocated_ports.remove(port)
97+
# Add to the back of the queue; oldest will be evicted when queue is full
98+
self._recently_released.append(port)
8699

87100
def release_all(self) -> None:
88101
"""Release all allocated ports."""
89102
with self._lock:
90103
self._allocated_ports.clear()
104+
self._recently_released.clear()
105+
106+
def reserve_existing_port(self, port: int) -> bool:
107+
"""Reserve a port that was allocated externally.
108+
109+
Args:
110+
port: The externally assigned port to reserve.
111+
112+
Returns:
113+
True if the port was reserved (or already reserved), False if the port value is invalid.
114+
115+
"""
116+
if port <= 0 or port > 65535:
117+
return False
118+
119+
with self._lock:
120+
if port in self._allocated_ports:
121+
return True
122+
123+
# Remove from recently released queue if present (we're explicitly reserving it)
124+
if port in self._recently_released:
125+
# Create a new deque without this port
126+
self._recently_released = deque(
127+
(p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN
128+
)
129+
130+
self._allocated_ports.add(port)
131+
return True
91132

92133
@contextmanager
93134
def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:

tests/tests_fabric/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,16 @@ def teardown_process_group():
9494
_destroy_dist_connection()
9595

9696
# Release the port from the manager so it can be reused
97+
manager = get_port_manager()
98+
9799
if port_to_release is not None:
98-
get_port_manager().release_port(port_to_release)
100+
manager.release_port(port_to_release)
101+
102+
# If a process group created or updated MASTER_PORT during the test, make sure we track it and then clear it.
103+
if "MASTER_PORT" in os.environ:
104+
with contextlib.suppress(ValueError):
105+
manager.reserve_existing_port(int(os.environ["MASTER_PORT"]))
106+
os.environ.pop("MASTER_PORT", None)
99107

100108

101109
@pytest.fixture(autouse=True)

tests/tests_fabric/plugins/environments/test_lightning.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818

1919
from lightning.fabric.plugins.environments import LightningEnvironment
20+
from lightning.fabric.utilities.port_manager import get_port_manager
2021

2122

2223
@mock.patch.dict(os.environ, {}, clear=True)
@@ -84,5 +85,18 @@ def test_teardown():
8485
assert "WORLD_SIZE" not in os.environ
8586

8687

88+
@mock.patch.dict(os.environ, {}, clear=True)
def test_teardown_releases_port_and_env():
    """Teardown hands the main port back to the port manager and clears the MASTER_* variables."""
    environment = LightningEnvironment()
    reserved_port = environment.main_port
    # Precondition: the port obtained via `main_port` is tracked by the manager.
    assert reserved_port in get_port_manager()._allocated_ports

    environment.teardown()

    assert reserved_port not in get_port_manager()._allocated_ports
    for variable in ("MASTER_PORT", "MASTER_ADDR"):
        assert variable not in os.environ
99+
100+
87101
def test_detect():
88102
assert LightningEnvironment.detect()

0 commit comments

Comments
 (0)