Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 146 additions & 2 deletions swift/plugin/callback.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import time

import numpy as np
import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

from swift.utils import get_logger
from swift.utils import get_current_device, get_device_count, get_logger

logger = get_logger()

Expand All @@ -28,6 +31,147 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer
control.should_training_stop = True


extra_callbacks = []
class PerfMetricsLogCallback(TrainerCallback):
"""An callback for perf metrics (MFU etc) log implementation"""

def __init__(self):
    """Create the callback with empty timing state and no throughput baseline yet."""
    # Timestamps are unset until the `on_train_begin` / `on_step_begin` hooks record them.
    self.start_time = self.step_start_time = None
    # Seconds spent inside training steps, accumulated by `on_step_end`.
    self.elapsed = 0.0
    # Throughput baseline; assigned in `on_init_end`.
    self.device_tflops = None

def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
from swift.utils import get_env_args

# Top priority. Specify by ENV
tflops = get_env_args('DEVICE_TFLOPS', int, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

DEVICE_TFLOPS is parsed as an integer, which might be too restrictive as TFLOPS values are often floating-point numbers (e.g., from the estimation or lookup table). Using float would be more appropriate and consistent.

Suggested change
tflops = get_env_args('DEVICE_TFLOPS', int, None)
tflops = get_env_args('DEVICE_TFLOPS', float, None)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems float is correct.

device_count = max(get_device_count(), 1)
if tflops is not None:
logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]")
else:
# Run an estimation test.
dtype = kwargs.get('model').dtype
device = torch.device(get_current_device())
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
Comment on lines +50 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _retrieve_flops_from_map function provides an efficient way to get device TFLOPS from a lookup table. It should be called as a fallback before running the performance estimation test. This avoids running a potentially time-consuming benchmark if the device's performance is already known.

Suggested change
else:
# Run a estimating test.
dtype = kwargs.get('model').dtype
device = torch.device(get_current_device())
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
else:
device = torch.device(get_current_device())
tflops = self._retrieve_flops_from_map(device)
if tflops is not None:
device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else str(device)
logger.info(f'Retrieved TFLOPS from lookup table for {device_name}: {tflops} TFLOPS')
else:
# Run an estimating test.
dtype = kwargs.get('model').dtype
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')

# TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables.

self.device_tflops = tflops * device_count

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    """Remember the wall-clock time at which the current step started."""
    now = time.time()
    self.step_start_time = now

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    """Add the duration of the step that just finished to the running total."""
    step_duration = time.time() - self.step_start_time
    self.elapsed += step_duration

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    """Record the wall-clock time at which training started."""
    began_at = time.time()
    self.start_time = began_at

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
total_flos = getattr(state, 'total_flos', 0)
actual_flops = total_flos / self.elapsed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a potential ZeroDivisionError here if self.elapsed is 0. This could happen if on_log is called before on_step_end has completed. It's safer to add a guard for this case.

Suggested change
actual_flops = total_flos / self.elapsed
actual_flops = total_flos / self.elapsed if self.elapsed > 0 else 0

theoretical_max_flops = self.device_tflops * 1e12
mfu = actual_flops / theoretical_max_flops
logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]')
logs['MFU'] = round(mfu, 6)

@staticmethod
def _estimate_device_tflops_by_dtype(device: torch.device,
                                     dtype: torch.dtype,
                                     repeats: int = 60,
                                     dim: int = 8192,
                                     min_duration: float = 3.0,
                                     target_duration: float = 6.0):
    """Estimate the achievable TFLOPS of ``device`` with a square-matmul benchmark.

    Two random ``dim x dim`` matrices of ``dtype`` are multiplied ``repeats`` times
    (after 5 warm-up iterations). If that timed pass finishes in less than
    ``min_duration`` seconds, the iteration count is rescaled so that a second pass
    lasts roughly ``target_duration`` seconds, and the second measurement is used.

    Args:
        device: Device to benchmark; ``device.type`` selects the synchronize call.
        dtype: Element type of the benchmark matrices.
        repeats: Number of timed matmuls in the first pass (values below 1 are
            treated as 1).
        dim: Side length of the square matrices.
        min_duration: Minimum acceptable duration of the first pass, in seconds.
        target_duration: Desired duration of the re-run pass, in seconds.

    Returns:
        The measured throughput in TFLOPS, computed as
        ``2 * dim**3 / avg_seconds_per_matmul / 1e12``.
    """
    backend = device.type
    if backend == 'npu':
        # Not referenced by name below; presumably imported so that `torch.npu`
        # becomes available -- TODO confirm.
        import torch_npu  # noqa: F401

    def _synchronize():
        # Wait for outstanding work on the benchmarked device before reading the clock.
        if backend == 'cuda':
            torch.cuda.synchronize(device)
        elif backend == 'npu':
            torch.npu.synchronize(device)
        elif backend == 'cpu':
            torch.cpu.synchronize(device)

    def _timed_matmuls(lhs, rhs, count):
        # Run `count` matmuls, synchronizing after each one, and return the elapsed seconds.
        started = time.perf_counter()
        for _ in range(count):
            torch.matmul(lhs, rhs)
            _synchronize()
        return time.perf_counter() - started

    # A non-positive count would make the per-iteration average undefined.
    repeats = max(1, repeats)

    # Build the benchmark operands.
    shape = (dim, dim)
    a = torch.randn(*shape, device=device, dtype=dtype)
    b = torch.randn(*shape, device=device, dtype=dtype)

    # Warm-up (result discarded).
    _timed_matmuls(a, b, 5)

    # First timed pass.
    total_time = _timed_matmuls(a, b, repeats)
    avg_time = total_time / repeats

    # If the pass was too short, rescale the iteration count and measure again.
    # The `avg_time > 0` guard avoids dividing by a zero measurement.
    if total_time < min_duration and avg_time > 0:
        repeats = max(1, int(target_duration / avg_time))
        total_time = _timed_matmuls(a, b, repeats)
        avg_time = total_time / repeats

    # Drop the operands and release cached device memory before training starts.
    del a, b
    if backend == 'cuda':
        torch.cuda.empty_cache()
    elif backend == 'npu':
        torch.npu.empty_cache()

    # A dim x dim matmul performs 2 * dim^3 floating-point operations.
    tflops = (2 * dim**3 / avg_time) / 1e12
    print(f'[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS')

    return tflops

@staticmethod
def _retrieve_flops_from_map(device):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed this function is not being used. Just curious, what's the reason?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More data should be collected, and hopefully the function _retrieve_flops_from_map can become a more efficient and accurate way to obtain device TFLOPS. This avoids running a potentially time-consuming benchmark if the device's performance is already known.

"""Retrieve theoretical FLOPS from Map. """

device_name = device.get_device_name()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the get_device_name function?

flops = None
for name, value in device_flops_map:
if name in device_name:
flops = value
break

return flops


# Theoretical peak throughput per device, in FLOPS (not TFLOPS), keyed by a substring
# of the device name. Presumably these are dense BF16/FP16 tensor-core peaks -- confirm
# against vendor datasheets before relying on them.
#
# Insertion order matters for first-match substring lookups: a key must appear before
# any shorter key it contains ('GB200' before 'B200', 'H200' before 'H20',
# 'L40S' before 'L40').
device_flops_map = {
    'GB200': 2.5e15,
    'B200': 2.25e15,
    'MI300X': 1336e12,
    # H100 / H800 share the H200's compute peak; 312e12 is the A100 figure.
    'H100': 989e12,
    'H800': 989e12,
    'H200': 989e12,
    'A100': 312e12,
    'A800': 312e12,
    'L40S': 362.05e12,
    'L40': 181.05e12,
    'A40': 149.7e12,
    'L20': 119.5e12,
    'H20': 148e12,
    '910B': 354e12,
    'Ascend910': 354e12,
    'RTX 3070 Ti': 21.75e12,
}

# Instantiates PerfMetricsLogCallback at import time and lists it in `extra_callbacks`.
# NOTE(review): presumably every callback listed here is attached to the trainer, which
# would turn perf-metrics logging on by default -- confirm whether this should be opt-in.
extra_callbacks = [PerfMetricsLogCallback()]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not enable by default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment. Is a switch argument required?

# A simple example of an EarlyStop callback; uncomment the line below to use it
# extra_callbacks = [EarlyStopCallback()]