File tree — 2 files changed: +6 −4 lines changed
Original file line number | Diff line number | Diff line change
72  72        if: contains(matrix.alias, 'distributed')
73  73        run: |
74  74          set -euxo pipefail
75+ GPU_COUNT=$(nvidia-smi -L | wc -l)
76+ if [ "$GPU_COUNT" -ne 4 ]; then
77+ echo "Error: Expected 4 GPUs but found $GPU_COUNT"
78+ exit 1
79+ fi
75  80          curl -L https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh -o install_cuda.sh
76  81          chmod +x install_cuda.sh
77  82          source install_cuda.sh
@@ -155,7 +160,7 @@ jobs:
155  160         # -rf: print failed tests
156  161         # --timeout: max allowed time for each test
157  162         TEST_PATH=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "test/test_examples_dist.py" || echo ".")
158       -     EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "" || echo "--ignore=test/test_examples_dist.py")
     163  +     EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "-vs " || echo "--ignore=test/test_examples_dist.py")
159  164         pytest -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH
160  165
161166 test-notebooks :
Original file line number | Diff line number | Diff line change
4   4   import torch.distributed as dist
5   5   import torch.distributed._symmetric_memory as symm_mem
6   6   from torch.testing._internal.common_distributed import MultiProcessTestCase
7       - from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
8   7   from torch.testing._internal.common_utils import instantiate_parametrized_tests
9   8   from torch.testing._internal.common_utils import run_tests
10  9
@@ -43,7 +42,6 @@ def _init_process(self):
43  42          )
44  43          torch.manual_seed(42 + self.rank)
45  44
46      -   @skip_if_lt_x_gpu(4)
47  45      def test_all_gather_matmul(self):
48  46          self._init_process()
49  47
@@ -100,7 +98,6 @@ def test_all_gather_matmul(self):
100  98          torch.cuda.current_stream().wait_stream(backend_stream)
101  99          dist.destroy_process_group()
102  100
103       -   @skip_if_lt_x_gpu(4)
104  101      def test_all_reduce(self):
105  102          self._init_process()
106  103
You can’t perform that action at this time.
0 commit comments