Labels: bug (Something isn't working), strategy: ddp (DistributedDataParallel), ver: 2.5.x
Bug description
The ddp_fork case of test_memory_sharing_disabled in tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py is failing on the macos14-3.12-2.8 CI job.
We may have to undo the change introduced here.
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
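A minimal repro sketch, adapted from the test shown in the traceback below; run it on macOS (or run the pytest command in the first comment) to hit the fork() abort. The function name run_fn is a placeholder, not the helper used in the test suite.

# pytest tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py -k ddp_fork
import torch
from lightning_fabric import Fabric


def run_fn(fabric, tensor, model):
    # The launched function receives the Fabric instance first; the abort in this
    # issue happens in the forked child processes before this body runs.
    fabric.barrier()


if __name__ == "__main__":
    tensor = torch.rand(4)
    model = torch.nn.Linear(2, 2)
    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_fork")
    fabric.launch(run_fn, tensor, model)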
Error messages and logs
____________________ test_memory_sharing_disabled[ddp_fork] ____________________
strategy = 'ddp_fork'
@RunIf(skip_windows=True)
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork"])
def test_memory_sharing_disabled(strategy):
    """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
    conditions on model updates."""
    tensor = torch.rand(4)
    model = SimpleModel()
    assert not tensor.is_shared()
    assert not model.layer.weight.is_shared()
    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
    fabric = Fabric(accelerator="cpu", devices=2, strategy=strategy)
>   fabric.launch(_test_memory_sharing_disabled, tensor, model)
strategies/launchers/test_multiprocessing_integration.py:45:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.venv/lib/python3.12/site-packages/lightning_fabric/fabric.py:986: in launch
return self._wrap_and_launch(function, self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../.venv/lib/python3.12/site-packages/lightning_fabric/fabric.py:1071: in _wrap_and_launch
return launcher.launch(to_run, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../.venv/lib/python3.12/site-packages/lightning_fabric/strategies/launchers/multiprocessing.py:117: in launch
mp.start_processes(
../../.venv/lib/python3.12/site-packages/torch/multiprocessing/spawn.py:296: in start_processes
while not context.join():
^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torch.multiprocessing.spawn.ProcessContext object at 0x11636f5c0>
timeout = None, grace_period = None
def join(
self, timeout: Optional[float] = None, grace_period: Optional[float] = None
):
r"""Join one or more processes within spawn context.
Attempt to join one or more processes in this spawn context.
If one of them exited with a non-zero exit status, this function
kills the remaining processes (optionally with a grace period)
and raises an exception with the cause of the first process exiting.
Returns ``True`` if all processes have been joined successfully,
``False`` if there are more processes that need to be joined.
Args:
timeout (float): Wait this long (in seconds) before giving up on waiting.
grace_period (float): When any processes fail, wait this long (in seconds)
for others to shutdown gracefully before terminating them. If they
still don't exit, wait another grace period before killing them.
"""
# Ensure this function can be called even when we're done.
if len(self.sentinels) == 0:
return True
# Wait for any process to fail or all of them to succeed.
ready = multiprocessing.connection.wait(
self.sentinels.keys(),
timeout=timeout,
)
error_index = None
for sentinel in ready:
index = self.sentinels.pop(sentinel)
process = self.processes[index]
process.join()
if process.exitcode != 0:
error_index = index
break
# Return if there was no error.
if error_index is None:
# Return whether or not all processes have been joined.
return len(self.sentinels) == 0
# An error occurred. Clean-up all processes before returning.
# First, allow a grace period for processes to shutdown themselves.
if grace_period is not None:
self._join_procs_with_timeout(grace_period)
# Then, terminate processes that are still alive. Try SIGTERM first.
for process in self.processes:
if process.is_alive():
log.warning("Terminating process %s via signal SIGTERM", process.pid)
process.terminate()
# Try SIGKILL if the process isn't going down after another grace_period.
# The reason is related to python signal handling is limited
# to main thread and if that is in c/c++ land and stuck it won't
# to handle it. We have seen processes getting stuck not handling
# SIGTERM for the above reason.
self._join_procs_with_timeout(30 if grace_period is None else grace_period)
for process in self.processes:
if process.is_alive():
log.warning(
"Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
process.pid,
)
process.kill()
process.join()
# The file will only be created if the process crashed.
failed_process = self.processes[error_index]
if not os.access(self.error_files[error_index], os.R_OK):
exitcode = self.processes[error_index].exitcode
if exitcode < 0:
try:
name = signal.Signals(-exitcode).name
except ValueError:
name = f"<Unknown signal {-exitcode}>"
> raise ProcessExitedException(
f"process {error_index:d} terminated with signal {name}",
error_index=error_index,
error_pid=failed_process.pid,
exit_code=exitcode,
signal_name=name,
)
E torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGABRT
../../.venv/lib/python3.12/site-packages/torch/multiprocessing/spawn.py:196: ProcessExitedException
----------------------------- Captured stderr call -----------------------------
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
objc[3067]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3067]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
objc[3068]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3068]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
----------------------------- Captured stderr call -----------------------------
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
objc[3069]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3069]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
objc[3070]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3070]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
----------------------------- Captured stderr call -----------------------------
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
objc[3072]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3071]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3072]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
objc[3071]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
W1106 20:40:51.637000 2837 torch/multiprocessing/spawn.py:169] Terminating process 3072 via signal SIGTERM
----------------------------- Captured stderr call -----------------------------
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
objc[3074]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3074]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
objc[3073]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called.
objc[3073]: +[__NSCFConstantString initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
W1106 20:40:51.761000 2837 torch/multiprocessing/spawn.py:169] Terminating process 3073 via signal SIGTERM
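The objc messages above are the macOS fork-safety check firing: the forked workers abort with SIGABRT because an Objective-C +initialize was running in another thread of the parent when fork() was called. Two possible local workarounds, offered as a sketch rather than a verified fix (the environment variable is an Objective-C runtime switch, not a Lightning option):

# Option 1: disable the Objective-C fork-safety check before the parent process
# starts (debugging aid only, not something to bake into CI without discussion):
#   OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES pytest tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py -k ddp_fork
# Option 2: avoid fork() on macOS and use the spawn-based launcher instead:
from lightning_fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")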
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response