Skip to content
Merged

Fix ppu #6489

Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion swift/cli/pt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

if __name__ == '__main__':
    # fix_ppu() may rewrite CUDA_VISIBLE_DEVICES and LOCAL_RANK, so it is
    # called before swift.llm is imported. NOTE(review): presumably this is so
    # that nothing pulled in by that import reads the old values first; confirm.
    from swift.cli.utils import fix_ppu
    fix_ppu()
    # Deliberately imported here rather than at module top level; a top-level
    # import would run before fix_ppu().
    from swift.llm import pt_main
    pt_main()
4 changes: 3 additions & 1 deletion swift/cli/rlhf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

if __name__ == '__main__':
    # fix_ppu() may rewrite CUDA_VISIBLE_DEVICES and LOCAL_RANK, so it is
    # called before swift.llm is imported. NOTE(review): presumably this is so
    # that nothing pulled in by that import reads the old values first; confirm.
    from swift.cli.utils import fix_ppu
    fix_ppu()
    # Deliberately imported here rather than at module top level; a top-level
    # import would run before fix_ppu().
    from swift.llm import rlhf_main
    rlhf_main()
2 changes: 2 additions & 0 deletions swift/cli/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def try_init_unsloth():


if __name__ == '__main__':
from swift.cli.utils import fix_ppu
fix_ppu()
try_init_unsloth()
from swift.ray import try_init_ray
try_init_ray()
Expand Down
18 changes: 18 additions & 0 deletions swift/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
def is_ppu():
    """Report whether ``nvidia-smi`` lists a PPU device on this machine.

    Runs ``nvidia-smi`` with a 10 second timeout and looks for the marker
    ``'PPU-'`` in its standard output.

    Returns:
        bool: True only when the command exits with status 0 and its output
        contains ``'PPU-'``. False when it exits non-zero, cannot be started
        (for example the binary is not installed), or times out.
    """
    import subprocess
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing or non-executable binary; SubprocessError
        # covers TimeoutExpired. A failed probe must not abort CLI start-up,
        # so it is reported as "not a PPU".
        return False
    return result.returncode == 0 and 'PPU-' in result.stdout


def fix_ppu():
    """Pin this process to a single visible device when running on a PPU.

    When ``is_ppu()`` is true, ``CUDA_VISIBLE_DEVICES`` is narrowed to the one
    entry found at index ``LOCAL_RANK`` of its comma-separated list, and
    ``LOCAL_RANK`` is reset to ``'0'``.

    The environment is left untouched when ``is_ppu()`` is false, or when
    either ``CUDA_VISIBLE_DEVICES`` or ``LOCAL_RANK`` is unset or empty.

    Raises:
        ValueError: If ``LOCAL_RANK`` is not an integer.
        IndexError: If ``LOCAL_RANK`` does not index an entry of
            ``CUDA_VISIBLE_DEVICES``.
    """
    import os
    local_rank = os.environ.get('LOCAL_RANK')
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    # Without both variables there is nothing to remap. Checking this first
    # also avoids spawning nvidia-smi when the result could not be used.
    if not local_rank or not visible:
        return
    if not is_ppu():
        return

    devices = [device.strip() for device in visible.split(',')]
    rank = int(local_rank)
    # Reject negative ranks explicitly; plain list indexing would silently
    # wrap around to the last device.
    if not 0 <= rank < len(devices):
        raise IndexError(f'LOCAL_RANK={rank} is out of range for CUDA_VISIBLE_DEVICES={visible!r}')

    os.environ['CUDA_VISIBLE_DEVICES'] = devices[rank]
    # Only one device remains visible, so it is addressed as local rank 0.
    os.environ['LOCAL_RANK'] = '0'
23 changes: 19 additions & 4 deletions swift/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,30 @@ def is_dist():
return rank >= 0 and local_rank >= 0


def is_ppu():
    """Return True when ``nvidia-smi`` succeeds and its output mentions ``'PPU-'``.

    The command is run with a 10 second timeout and its standard output is
    captured as text.

    Returns:
        bool: False when the command exits non-zero, cannot be launched (for
        example it is not installed), or times out; otherwise whether
        ``'PPU-'`` occurs in the captured output.
    """
    import subprocess
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        # A missing binary raises OSError (FileNotFoundError) and a hang raises
        # TimeoutExpired, a SubprocessError. Either way the probe is treated as
        # "not a PPU" and the exception does not reach callers.
        return False
    if result.returncode != 0:
        return False
    return 'PPU-' in result.stdout


def is_mp() -> bool:
    """Return whether each local process has at least two devices available.

    The device count from ``get_device_count()`` is divided by the local world
    size, taken from element 3 of ``get_dist_setting()``. The result is True
    when that quotient is 2 or more.

    Returns:
        bool: Always False when ``is_ppu()`` is true; otherwise
        ``n_gpu // local_world_size >= 2``.

    Raises:
        AssertionError: Outside PPU, if the device count is not a multiple of
            the local world size.
    """
    from swift.utils import get_device_count
    n_gpu = get_device_count()
    local_world_size = get_dist_setting()[3]
    if is_ppu():
        # We do not support mp for PPU
        return False
    assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
    return n_gpu // local_world_size >= 2


def is_mp_ddp() -> bool:
Expand Down
Loading