8 | 8 | import os |
9 | 9 | import sys |
10 | 10 | import time |
| 11 | +from tqdm import tqdm |
11 | 12 | from dataclasses import dataclass, field, fields |
12 | 13 | from typing import (TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, cast) |
13 | 14 |
@@ -3615,7 +3616,7 @@ def log_warmup(self, phase, i, max_i, first_dim, second_dim, third_dim, causal=F |
3615 | 3616 | f"query_len:{second_dim} " |
3616 | 3617 | f"num_blocks:{third_dim} " |
3617 | 3618 | f"free_mem:{free_mem}") |
3618 | | - logger.info(msg) |
| 3619 | + tqdm.write(msg) |
3619 | 3620 |
3620 | 3621 | def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, img_args): |
3621 | 3622 | free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory()) |
@@ -3765,45 +3766,57 @@ def warmup_graphs(self, buckets, is_prompt, kv_caches, starting_mem=0, total_bat |
3765 | 3766 | idx = 0 |
3766 | 3767 | num_candidates = len(buckets) |
3767 | 3768 | captured_all = True |
3768 | | - for idx, (batch_size, seq_len, num_blocks) in enumerate(reversed(buckets)): |
3769 | | - if seq_len > self.max_num_tokens: |
3770 | | - continue |
3771 | | - # Graph memory usage is proportional to seq dimension in a batch |
3772 | | - phase = f"Graph/{'prompt' if is_prompt else 'decode'}" |
3773 | | - if is_prompt: |
3774 | | - batch_seq = batch_size * seq_len * num_blocks if num_blocks else batch_size * seq_len |
3775 | | - else: |
3776 | | - batch_seq = batch_size |
3777 | | - |
3778 | | - graphed_bucket = (batch_size, seq_len, num_blocks, is_prompt) |
3779 | | - if graphed_bucket in self.graphed_buckets: |
3780 | | - continue |
3781 | | - self.graphed_buckets.add(graphed_bucket) |
3782 | | - self.log_warmup(phase, idx, num_candidates, batch_size, seq_len, num_blocks) |
3783 | | - prompt_cfg, decode_cfg = None, None |
3784 | | - with HabanaMemoryProfiler() as mem_prof: |
| 3769 | + developer_settings = get_config().VLLM_DEVELOPER_MODE |
| 3770 | + phase = 'Prompt' if is_prompt else 'Decode' |
| 3771 | + desc = f'{phase} warmup processing: ' |
| 3772 | + with tqdm(total=num_candidates, desc=desc, unit="item") as pbar: |
| 3773 | + for idx, (batch_size, seq_len, num_blocks) in enumerate(reversed(buckets)): |
| 3774 | + if seq_len > self.max_num_tokens: |
| 3775 | + continue |
| 3776 | + # Graph memory usage is proportional to seq dimension in a batch |
3785 | 3777 | if is_prompt: |
3786 | | - prompt_cfg = (batch_size, seq_len, num_blocks) |
| 3778 | + batch_seq = batch_size * seq_len * num_blocks if num_blocks else batch_size * seq_len |
3787 | 3779 | else: |
3788 | | - decode_cfg = (batch_size, 1, num_blocks) |
3789 | | - self._prepare_dummy_scenario(prompt_cfg, decode_cfg) |
3790 | | - # TODO(kzawora): align_workers |
3791 | | - used_mem = mem_prof.consumed_device_memory |
3792 | | - total_mem += used_mem |
3793 | | - total_batch_seq += batch_seq |
| 3780 | + batch_seq = batch_size |
| 3781 | + |
| 3782 | + graphed_bucket = (batch_size, seq_len, num_blocks, is_prompt) |
| 3783 | + if graphed_bucket in self.graphed_buckets: |
| 3784 | + continue |
| 3785 | + self.graphed_buckets.add(graphed_bucket) |
| 3786 | + if developer_settings: |
| 3787 | + self.log_warmup(phase, idx, num_candidates, batch_size, seq_len, num_blocks) |
| 3788 | + prompt_cfg, decode_cfg = None, None |
| 3789 | + with HabanaMemoryProfiler() as mem_prof: |
| 3790 | + if is_prompt: |
| 3791 | + prompt_cfg = (batch_size, seq_len, num_blocks) |
| 3792 | + else: |
| 3793 | + decode_cfg = (batch_size, 1, num_blocks) |
| 3794 | + self._prepare_dummy_scenario(prompt_cfg, decode_cfg) |
| 3795 | + # TODO(kzawora): align_workers |
| 3796 | + used_mem = mem_prof.consumed_device_memory |
| 3797 | + total_mem += used_mem |
| 3798 | + total_batch_seq += batch_seq |
| 3799 | + |
| 3800 | + pbar.set_postfix_str(f"{idx + 1}/{num_candidates}") |
| 3801 | + pbar.update(1) |
3794 | 3802 |
3795 | 3803 | return total_mem, total_batch_seq, captured_all |
3796 | 3804 |
3797 | 3805 | def warmup_unified_graphs(self, buckets, kv_cache): |
3798 | 3806 | idx = 0 |
3799 | 3807 | num_candidates = len(buckets) |
3800 | | - for idx, (query, shared_ctx, unique_ctx, is_causal) in enumerate(reversed(buckets)): |
3801 | | - unified_cfg = (query, shared_ctx, unique_ctx, is_causal) |
3802 | | - if unified_cfg in self.graphed_buckets: |
3803 | | - continue |
3804 | | - self.graphed_buckets.add(unified_cfg) |
3805 | | - self.log_warmup("Unified CFG", idx, num_candidates, query, shared_ctx, unique_ctx, is_causal) |
3806 | | - self._prepare_dummy_unified_scenario(unified_cfg) |
| 3808 | + developer_settings = get_config().VLLM_DEVELOPER_MODE |
| 3809 | + with tqdm(total=num_candidates, desc="Unified Attention warmup", unit="item") as pbar: |
| 3810 | + for idx, (query, shared_ctx, unique_ctx, is_causal) in enumerate(reversed(buckets)): |
| 3811 | + unified_cfg = (query, shared_ctx, unique_ctx, is_causal) |
| 3812 | + if unified_cfg in self.graphed_buckets: |
| 3813 | + continue |
| 3814 | + self.graphed_buckets.add(unified_cfg) |
| 3815 | + if developer_settings: |
| 3816 | + self.log_warmup("Unified CFG", idx, num_candidates, query, shared_ctx, unique_ctx, is_causal) |
| 3817 | + self._prepare_dummy_unified_scenario(unified_cfg) |
| 3818 | + pbar.set_postfix_str(f"{idx + 1}/{num_candidates}") |
| 3819 | + pbar.update(1) |
3807 | 3820 |
3808 | 3821 | def _add_dummy_request(self, |
3809 | 3822 | requests, |
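For context: the change pairs a tqdm progress bar with tqdm.write() so that warmup log lines print above the live bar instead of breaking it, and it gates the per-bucket lines behind a developer-mode flag. Below is a minimal, self-contained sketch of that pattern only; the bucket list, the developer_mode flag, and the sleep call are illustrative stand-ins, not vLLM code.

import time

from tqdm import tqdm

# Illustrative stand-ins (not vLLM objects): a few (batch_size, query_len)
# buckets and a flag playing the role of the developer-mode switch.
buckets = [(bs, seq) for bs in (1, 2, 4) for seq in (128, 256, 512)]
developer_mode = True

with tqdm(total=len(buckets), desc="Decode warmup", unit="item") as pbar:
    for idx, (batch_size, query_len) in enumerate(reversed(buckets)):
        if developer_mode:
            # tqdm.write() prints above the active bar rather than through it,
            # which is why the patch swaps logger.info(msg) for tqdm.write(msg).
            tqdm.write(f"[Warmup][{idx + 1}/{len(buckets)}] "
                       f"batch_size:{batch_size} query_len:{query_len}")
        time.sleep(0.05)  # placeholder for the real graph-capture work
        pbar.set_postfix_str(f"{idx + 1}/{len(buckets)}")
        pbar.update(1)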