-
Notifications
You must be signed in to change notification settings - Fork 1k
Add MFU logging support #6434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MFU logging support #6434
Changes from 3 commits
a820b3c
655eedb
05ba547
190ee31
60d2c5f
f1423a0
83b24f8
83137cd
7d8aa7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,8 +1,11 @@ | ||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | ||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
| from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from swift.utils import get_logger | ||||||||||||||||||||||||||||||||||||||||
| from swift.utils import get_current_device, get_device_count, get_logger | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| logger = get_logger() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -28,6 +31,147 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer | |||||||||||||||||||||||||||||||||||||||
| control.should_training_stop = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| extra_callbacks = [] | ||||||||||||||||||||||||||||||||||||||||
| class PerfMetricsLogCallback(TrainerCallback): | ||||||||||||||||||||||||||||||||||||||||
| """An callback for perf metrics (MFU etc) log implementation""" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||
| self.start_time = None | ||||||||||||||||||||||||||||||||||||||||
| self.device_tflops = None | ||||||||||||||||||||||||||||||||||||||||
| self.elapsed = 0.0 | ||||||||||||||||||||||||||||||||||||||||
| self.step_start_time = None | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| from swift.utils import get_env_args | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Top priority. Specify by ENV | ||||||||||||||||||||||||||||||||||||||||
| tflops = get_env_args('DEVICE_TFLOPS', int, None) | ||||||||||||||||||||||||||||||||||||||||
| device_count = max(get_device_count(), 1) | ||||||||||||||||||||||||||||||||||||||||
| if tflops is not None: | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]") | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| # Run a estimating test. | ||||||||||||||||||||||||||||||||||||||||
| dtype = kwargs.get('model').dtype | ||||||||||||||||||||||||||||||||||||||||
| device = torch.device(get_current_device()) | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]') | ||||||||||||||||||||||||||||||||||||||||
| tflops = self._estimate_device_tflops_by_dtype(device, dtype) | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]') | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+50
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| # TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| self.device_tflops = tflops * device_count | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| self.step_start_time = time.time() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| self.elapsed += time.time() - self.step_start_time | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| self.start_time = time.time() | ||||||||||||||||||||||||||||||||||||||||
y2logic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| total_flos = getattr(state, 'total_flos', 0) | ||||||||||||||||||||||||||||||||||||||||
| actual_flops = total_flos / self.elapsed | ||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| theoretical_max_flops = self.device_tflops * 1e12 | ||||||||||||||||||||||||||||||||||||||||
| mfu = actual_flops / theoretical_max_flops | ||||||||||||||||||||||||||||||||||||||||
| logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]') | ||||||||||||||||||||||||||||||||||||||||
| logs['MFU'] = round(mfu, 6) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||
| def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192): | ||||||||||||||||||||||||||||||||||||||||
| # 默认矩阵规模 | ||||||||||||||||||||||||||||||||||||||||
| shape = (dim, dim) | ||||||||||||||||||||||||||||||||||||||||
| backend = device.type | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 创建矩阵 | ||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(*shape, device=device, dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||
| b = torch.randn(*shape, device=device, dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 预热 | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(5): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'cuda': | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| torch.npu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'cpu': | ||||||||||||||||||||||||||||||||||||||||
| torch.cpu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
y2logic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 进行测试 | ||||||||||||||||||||||||||||||||||||||||
| start = time.time() | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(repeats): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'cuda': | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| torch.npu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'cpu': | ||||||||||||||||||||||||||||||||||||||||
| torch.cpu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| end = time.time() | ||||||||||||||||||||||||||||||||||||||||
| total_time = end - start | ||||||||||||||||||||||||||||||||||||||||
| avg_time = total_time / repeats | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 若测试时间过短,调整循环次数并重新测试 | ||||||||||||||||||||||||||||||||||||||||
| if total_time < 3: | ||||||||||||||||||||||||||||||||||||||||
| repeats = int(6 / avg_time) | ||||||||||||||||||||||||||||||||||||||||
| start = time.time() | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(repeats): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'cuda': | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| torch.npu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'cpu': | ||||||||||||||||||||||||||||||||||||||||
| torch.cpu.synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| end = time.time() | ||||||||||||||||||||||||||||||||||||||||
| total_time = end - start | ||||||||||||||||||||||||||||||||||||||||
| avg_time = total_time / repeats | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| del a, b, c | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'cuda': | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| torch.npu.empty_cache() | ||||||||||||||||||||||||||||||||||||||||
y2logic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| tflops = (2 * dim**3 / avg_time) / 1e12 | ||||||||||||||||||||||||||||||||||||||||
| print(f'[设备 {device}] 测试总耗时:{total_time:.4f}s,平均耗时: {avg_time:.4f} s,dtype:{dtype},性能: {tflops:.4f} TFLOPS') | ||||||||||||||||||||||||||||||||||||||||
y2logic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return tflops | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||
| def _retrieve_flops_from_map(device): | ||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I noticed this function is not being used. Just curious, what's the reason?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More data should be collected, and hopefully the function |
||||||||||||||||||||||||||||||||||||||||
| """Retrieve theoretical FLOPS from Map. """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| device_name = device.get_device_name() | ||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the get_device_name function? |
||||||||||||||||||||||||||||||||||||||||
| flops = None | ||||||||||||||||||||||||||||||||||||||||
| for name, value in device_flops_map: | ||||||||||||||||||||||||||||||||||||||||
y2logic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
| if name in device_name: | ||||||||||||||||||||||||||||||||||||||||
| flops = value | ||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return flops | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| device_flops_map = { | ||||||||||||||||||||||||||||||||||||||||
| 'GB200': 2.5e15, | ||||||||||||||||||||||||||||||||||||||||
| 'B200': 2.25e15, | ||||||||||||||||||||||||||||||||||||||||
| 'MI300X': 1336e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H100': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H800': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H200': 989e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A100': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A800': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L40S': 362.05e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L40': 181.05e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A40': 149.7e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L20': 119.5e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H20': 148e12, | ||||||||||||||||||||||||||||||||||||||||
| '910B': 354e12, | ||||||||||||||||||||||||||||||||||||||||
| 'Ascend910': 354e12, | ||||||||||||||||||||||||||||||||||||||||
| 'RTX 3070 Ti': 21.75e12 | ||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| extra_callbacks = [PerfMetricsLogCallback()] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
| # This example shows a simple example of EarlyStop Callback, uncomment this to use | ||||||||||||||||||||||||||||||||||||||||
| # extra_callbacks = [EarlyStopCallback()] | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DEVICE_TFLOPSis parsed as an integer, which might be too restrictive as TFLOPS values are often floating-point numbers (e.g., from the estimation or lookup table). Usingfloatwould be more appropriate and consistent.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems float is correct.