Skip to content

Commit 2a07f0a

Browse files
Fix ppu (#6489)
1 parent 10d9096 commit 2a07f0a

File tree

7 files changed

+30
-6
lines changed

7 files changed

+30
-6
lines changed

docs/source/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,3 +843,4 @@ qwen2_5_omni除了包含qwen2_5_vl和qwen2_audio的模型特定参数外,还
843843
- VLLM_USE_V1: 用于切换vLLM使用V0/V1版本。
844844
- SWIFT_TIMEOUT: (ms-swift>=3.10) 若多模态数据集中存在图像URL,该参数用于控制获取图片的timeout,默认为20s。
845845
- ROOT_IMAGE_DIR: (ms-swift>=3.8) 图像(多模态)资源的根目录。通过设置该参数,可以在数据集中使用相对于 `ROOT_IMAGE_DIR` 的相对路径。默认情况下,是相对于运行目录的相对路径。
846+
- SWIFT_SINGLE_DEVICE_MODE: (ms-swift>=3.10) 单设备模式,在此模式下,所有进程只能看到一个设备,目前用于兼容PPU设备

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,3 +868,4 @@ The meanings of the following parameters can be found in the example code [here]
868868
- VLLM_USE_V1: Used to switch between V0 and V1 versions of vLLM.
869869
- SWIFT_TIMEOUT: (ms-swift >= 3.10) If the multimodal dataset contains image URLs, this parameter controls the timeout for fetching images, defaulting to 20 seconds.
870870
- ROOT_IMAGE_DIR: (ms-swift>=3.8) The root directory for image (multimodal) resources. By setting this parameter, relative paths in the dataset can be interpreted relative to `ROOT_IMAGE_DIR`. By default, paths are relative to the current working directory.
871+
- SWIFT_SINGLE_DEVICE_MODE: (ms-swift>=3.10) Single device mode. In this mode, all processes can only see one device. Currently used for compatibility with PPU devices.

swift/cli/pt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
from swift.llm import pt_main
32

43
if __name__ == '__main__':
4+
from swift.cli.utils import try_use_single_device_mode
5+
try_use_single_device_mode()
6+
from swift.llm import pt_main
57
pt_main()

swift/cli/rlhf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
from swift.llm import rlhf_main
32

43
if __name__ == '__main__':
4+
from swift.cli.utils import try_use_single_device_mode
5+
try_use_single_device_mode()
6+
from swift.llm import rlhf_main
57
rlhf_main()

swift/cli/sft.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def try_init_unsloth():
1111

1212

1313
if __name__ == '__main__':
14+
from swift.cli.utils import try_use_single_device_mode
15+
try_use_single_device_mode()
1416
try_init_unsloth()
1517
from swift.ray import try_init_ray
1618
try_init_ray()

swift/cli/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
3+
4+
def try_use_single_device_mode():
5+
if os.environ.get('SWIFT_SINGLE_DEVICE_MODE', '0') == '1':
6+
visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
7+
local_rank = os.environ.get('LOCAL_RANK')
8+
if local_rank is None or not visible_devices:
9+
return
10+
visible_devices = visible_devices.split(',')
11+
visible_device = visible_devices[int(local_rank)]
12+
os.environ['CUDA_VISIBLE_DEVICES'] = str(visible_device)
13+
os.environ['LOCAL_RANK'] = '0'

swift/utils/env.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ def is_mp() -> bool:
6666
from swift.utils import get_device_count
6767
n_gpu = get_device_count()
6868
local_world_size = get_dist_setting()[3]
69-
assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
70-
if n_gpu // local_world_size >= 2:
71-
return True
72-
return False
69+
if os.environ.get('SWIFT_SINGLE_DEVICE_MODE', '0') != '1':
70+
assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
71+
if n_gpu // local_world_size >= 2:
72+
return True
73+
return False
74+
else:
75+
return False
7376

7477

7578
def is_mp_ddp() -> bool:

0 commit comments

Comments
 (0)