"""
Usage:
python3 -m flexgen.flex_qwen --model Qwen/Qwen1.5-0.5B-Chat --gpu-batch-size 32 --percent 100 0 100 0 100 0
"""
import os
import torch
import argparse
from typing import Union
from transformers import AutoTokenizer
from flexgen.compression import CompressionConfig
from flexgen.qwen_config import QwenConfig, get_qwen_config, download_qwen_weights
from flexgen.flex_llama import LlamaInputEmbed, LlamaOutputEmbed, LlamaMLP
from flexgen.pytorch_backend import QwenTorchDevice, TorchDisk, TorchMixedDevice, fix_recursive_import
from flexgen.flex_opt import (Policy, init_weight_list, SelfAttention, TransformerLayer,
    OptLM, get_filename, get_test_inputs)
from flexgen.timer import timers
from flexgen.utils import (ExecutionEnv, GB, ValueHolder,
    array_1d, array_2d, str2bool, project_decode_latency, write_benchmark_log)

fix_recursive_import()

DUMMY_WEIGHT = "_DUMMY_"  # Use dummy weights for benchmark purposes


class QwenSelfAttention(SelfAttention):
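    """Self-attention block for Qwen models.

    Mirrors FlexGen's Llama attention, but additionally loads the Q/K/V
    projection biases present in Qwen checkpoints and passes Qwen's
    rms_norm_eps and rope_theta to the attention kernels.
    """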
    def __init__(self, config, env, policy, layer_id):
        super().__init__(config, env, policy, layer_id)

    def init_weight(self, weight_home, path):
        h, n_head, n_kv_head, dtype = (
            self.config.input_dim, self.config.n_head,
            self.config.num_key_value_heads, self.config.dtype)
        head_dim = h // n_head
        path = os.path.join(path, f"layers.{self.layer_id}.")
        weight_specs = [
            # w_ln
            ((h,), dtype, path + "input_layernorm.weight"),
            # w_q
            ((h, n_head*head_dim), dtype, path + "self_attn.q_proj.weight"),
            # b_q
            ((n_head*head_dim,), dtype, path + "self_attn.q_proj.bias"),
            # w_k
            ((n_kv_head*head_dim, h), dtype, path + "self_attn.k_proj.weight"),
            # b_k (sized to the k projection output; equals h only when n_kv_head == n_head)
            ((n_kv_head*head_dim,), dtype, path + "self_attn.k_proj.bias"),
            # w_v
            ((n_kv_head*head_dim, h), dtype, path + "self_attn.v_proj.weight"),
            # b_v (sized to the v projection output)
            ((n_kv_head*head_dim,), dtype, path + "self_attn.v_proj.bias"),
            # w_o
            ((n_head*head_dim, h), dtype, path + "self_attn.o_proj.weight"),
        ]
        weights = init_weight_list(weight_specs, self.policy, self.env)
        weight_home.store(weights)

    def load_weight(self, weight_home, weight_read_buf, k):
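        # On the first GPU batch, stage this layer's weights: the projection
        # matrices are copied to self.weight_load_dst, while the small layernorm
        # and bias vectors go directly to the compute device.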
        w_ln, w_q, b_q, w_k, b_k, w_v, b_v, w_o = weight_home.val
        if k == 0:
            dst1 = self.weight_load_dst
            dst2 = self.compute
            weight_read_buf.store((
                w_ln.smart_copy(dst2),
                w_q.smart_copy(dst1), b_q.smart_copy(dst2),
                w_k.smart_copy(dst1), b_k.smart_copy(dst2),
                w_v.smart_copy(dst1), b_v.smart_copy(dst2),
                w_o.smart_copy(dst1)))

    def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask,
                cache_write_buf, i, k):
        n_head = self.config.n_head
        n_kv_head = self.config.num_key_value_heads

        donate = [False] * 12
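        # Entries in `donate` mark tensors that the backend may free or reuse
        # once this layer's computation has consumed them.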
        h, donate[0] = hidden.val, True

        if k == self.policy.num_gpu_batches - 1:
            # Clear the weight_read_buf if it is the last gpu batch
            ((w_ln, donate[2]), (w_q, donate[3]), (b_q, donate[4]), (w_k, donate[5]), (b_k, donate[6]),
             (w_v, donate[7]), (b_v, donate[8]), (w_o, donate[9])) = weight_read_buf.pop()
        else:
            ((w_ln, _), (w_q, _), (b_q, _), (w_k, _), (b_k, _), (w_v, _), (b_v, _),
             (w_o, _)) = weight_read_buf.val
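
        # Prefill (i == 0) attends over the whole prompt and writes the initial
        # KV cache; later decoding steps reuse that cache and append one K/V
        # entry per new token. Position ids are derived from the cumulative sum
        # of the attention mask, so left padding does not advance the position.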
        if i == 0:  # prefill
            mask, donate[1] = attention_mask.val.smart_copy(self.compute)
            position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1
            h, new_k_cache, new_v_cache = self.compute.qwen_mha(h, position_ids, mask, w_ln,
                w_q, b_q, w_k, b_k, w_v, b_v, w_o, n_head, n_kv_head, donate,
                self.config.rms_norm_eps, self.config.rope_theta,
                self.policy.compress_cache, self.policy.comp_cache_config)
            cache_write_buf.store((new_k_cache, new_v_cache))
        else:  # decoding
            mask, donate[1] = attention_mask.val.smart_copy(self.attention_compute)
            (k_cache, donate[10]), (v_cache, donate[11]) = cache_read_buf.pop()
            position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1
            position_ids = position_ids[:, -h.shape[1]].unsqueeze(1)
            h, new_k_cache, new_v_cache = self.compute.qwen_mha_gen(h, position_ids, mask, w_ln,
                w_q, b_q, w_k, b_k, w_v, b_v, w_o,
                self.config.rms_norm_eps, self.config.rope_theta, n_head, n_kv_head,
                k_cache, v_cache, donate, self.policy.attn_sparsity,
                self.policy.compress_cache, self.policy.comp_cache_config)
            cache_write_buf.store((new_k_cache, new_v_cache))

        hidden.val = h


class QwenTransformerLayer(TransformerLayer):
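    """Fused attention + MLP layer, used when --sep-layer is disabled."""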
    def __init__(self, config, env, policy, i):
        self.attention = QwenSelfAttention(config, env, policy, i)
        self.mlp = LlamaMLP(config, env, policy, i)
        self.policy = policy
        self.compute = self.attention.compute


class QwenLM(OptLM):
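    """Qwen language model driven by FlexGen's OptLM generation loop.

    Reuses the Llama input/output embeddings and MLP layers; only the
    attention layers (and the model config) are Qwen-specific.
    """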
    def __init__(self,
                 config: Union[str, QwenConfig],
                 env: ExecutionEnv,
                 path: str,
                 policy: Policy):
        if isinstance(config, str):
            config = get_qwen_config(config)
        self.config = config
        self.env = env
        self.path = path
        self.policy = policy
        self.num_gpu_batches = policy.num_gpu_batches

        layers = []
        layers.append(LlamaInputEmbed(self.config, self.env, self.policy))
        for i in range(self.config.num_hidden_layers):
            if policy.sep_layer:
                layers.append(QwenSelfAttention(self.config, self.env, self.policy, i))
                layers.append(LlamaMLP(self.config, self.env, self.policy, i))
            else:
                layers.append(QwenTransformerLayer(self.config, self.env, self.policy, i))
        layers.append(LlamaOutputEmbed(self.config, self.env, self.policy))
        self.layers = layers
        self.num_layers = len(layers)
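
        # Activations must be placed entirely on a single device; splitting
        # activations across devices is not supported.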
        if self.policy.act_gpu_percent == 100:
            self.act_home = self.env.gpu
        elif self.policy.act_cpu_percent == 100:
            self.act_home = self.env.cpu
        elif self.policy.act_disk_percent == 100:
            self.act_home = self.env.disk
        else:
            raise NotImplementedError()

        # CUDA streams
        self.load_weight_stream = torch.cuda.Stream()
        self.load_cache_stream = torch.cuda.Stream()
        self.store_cache_stream = torch.cuda.Stream()

        # Intermediate tensors
        # The following buffers store values used
        # for the i-th token, j-th layer, k-th gpu batch.
        num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches

        # cache[j][k]
        self.cache_home = array_2d(num_layers, num_gpu_batches, ValueHolder)
        self.cache_read_buf = array_2d(num_layers, num_gpu_batches, ValueHolder)
        self.cache_write_buf = array_2d(num_layers, num_gpu_batches, ValueHolder)
        # weight[j]
        self.weight_read_buf = array_1d(num_layers, ValueHolder)
        # attention_mask[k]
        self.attention_mask = array_1d(num_gpu_batches, ValueHolder)

        self.task = None
        self.init_all_weights()

    def init_weight(self, j):
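        # Weights are expected as per-tensor files under "<path>/<model-name>-np".
        # If they are missing and dummy weights were not requested, download
        # them from HuggingFace via download_qwen_weights.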
        expanded_path = os.path.abspath(os.path.expanduser(
            os.path.join(self.path, f"{self.config.name}-np")))
        check_path = os.path.join(expanded_path, "embed_tokens.weight")
        if not os.path.exists(check_path) and DUMMY_WEIGHT not in check_path:
            download_qwen_weights(self.config.name, self.path)

        self.layers[j].init_weight(self.weight_home[j], expanded_path)


def run_flexgen(args):
    print(f"<run_flexgen>: args.model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    num_prompts = args.num_gpu_batches * args.gpu_batch_size
    prompt_len, gen_len, cut_gen_len = args.prompt_len, args.gen_len, args.cut_gen_len

    # Task and policy
    warmup_inputs = get_test_inputs(32, num_prompts, tokenizer)
    inputs = get_test_inputs(prompt_len, num_prompts, tokenizer)
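
    # Build the three-tier execution environment: GPU, CPU, and disk, plus a
    # "mixed" device that can split a single tensor across the three tiers
    # according to the policy percentages.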
    gpu = QwenTorchDevice("cuda:0")
    cpu = QwenTorchDevice("cpu")
    disk = TorchDisk(args.offload_dir)
    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))

    policy = Policy(args.gpu_batch_size, args.num_gpu_batches,
                    args.percent[0], args.percent[1],
                    args.percent[2], args.percent[3],
                    args.percent[4], args.percent[5],
                    args.overlap, args.sep_layer, args.pin_weight,
                    args.cpu_cache_compute, args.attn_sparsity,
                    args.compress_weight,
                    CompressionConfig(num_bits=4, group_size=64,
                                      group_dim=0, symmetric=False),
                    args.compress_cache,
                    CompressionConfig(num_bits=4, group_size=64,
                                      group_dim=2, symmetric=False))
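    # Both CompressionConfigs use group-wise 4-bit quantization (group size 64);
    # weights are grouped along dim 0 and the KV cache along dim 2. They take
    # effect only when --compress-weight / --compress-cache are set.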
    assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented"

    qwen_config = get_qwen_config(args.model, pad_token_id=tokenizer.eos_token_id)
    cache_size = qwen_config.cache_bytes(num_prompts, prompt_len + gen_len)
    hidden_size = qwen_config.hidden_bytes(num_prompts, prompt_len + gen_len)
    print(f"model size: {qwen_config.model_bytes()/GB:.3f} GB, "
          f"cache size: {cache_size/GB:.3f} GB, "
          f"hidden size (prefill): {hidden_size/GB:.3f} GB")

    print("init weight...")
    model = QwenLM(qwen_config, env, args.path, policy)

    try:
        print("warmup - generate")
        output_ids = model.generate(
            warmup_inputs, max_new_tokens=1, verbose=args.verbose)

        print("benchmark - generate")
        timers("generate").reset()
        output_ids = model.generate(
            inputs, max_new_tokens=args.gen_len,
            debug_mode=args.debug_mode, cut_gen_len=cut_gen_len, verbose=args.verbose)
        costs = timers("generate").costs
    finally:
        env.close_copy_threads()

    # Log output
    prefill_latency = costs[0]
    prefill_throughput = num_prompts * prompt_len / prefill_latency
    if cut_gen_len:  # project latency of cut_gen_len to gen_len
        decode_latency = project_decode_latency(costs, prompt_len, gen_len)
    else:
        decode_latency = sum(costs[1:])
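    # The first token of each sequence is produced during prefill, so decoding
    # accounts for gen_len - 1 tokens per prompt.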
    decode_throughput = num_prompts * (gen_len - 1) / max(decode_latency, 1e-10)
    num_generated_tokens = num_prompts * gen_len
    total_latency = prefill_latency + decode_latency
    total_throughput = num_generated_tokens / total_latency
    _, gpu_peak_mem = gpu.mem_stats()
    _, cpu_peak_mem = cpu.mem_stats()

    if DUMMY_WEIGHT not in args.path:
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        show_str = "Outputs:\n" + 70 * '-' + "\n"
        for i in [0, len(outputs)-1]:
            show_str += f"{i}: {outputs[i]}\n"
            show_str += "-" * 70 + "\n"
        if args.verbose >= 2:
            print(show_str)

    gpu.print_stats()
    cpu.print_stats()
    projected = bool(args.debug_mode or cut_gen_len)

    if args.log_file == "auto":
        filename = get_filename(args) + ".log"
    else:
        filename = args.log_file

    log_str = write_benchmark_log(filename,
        qwen_config.model_bytes(), cache_size, hidden_size,
        gpu_peak_mem, projected, prefill_latency, prefill_throughput,
        decode_latency, decode_throughput, total_latency, total_throughput)
    if args.verbose >= 1:
        print(log_str)


def add_parser_arguments(parser):
| 273 | + parser.add_argument("--model", type=str, default="Qwen/Qwen1.5-7B-Chat", |
| 274 | + help="The model name.") |
| 275 | + parser.add_argument("--path", type=str, default="~/qwen_weights", |
| 276 | + help="The path to the model weights. If there are no cached weights, " |
| 277 | + "FlexGen will automatically download them from HuggingFace.") |
| 278 | + parser.add_argument("--offload-dir", type=str, default="~/flexgen_offload_dir", |
| 279 | + help="The directory to offload tensors. ") |
| 280 | + parser.add_argument("--prompt-len", type=int, default=512) |
| 281 | + parser.add_argument("--gen-len", type=int, default=32) |
| 282 | + parser.add_argument("--cut-gen-len", type=int, |
| 283 | + help="Cut generation length for fast debugging.") |
| 284 | + parser.add_argument("--debug-mode", type=str, |
| 285 | + choices=["fewer_batch", "breakdown"]) |
| 286 | + parser.add_argument("--gpu-batch-size", type=int, default=4) |
| 287 | + parser.add_argument("--num-gpu-batches", type=int, default=1) |
| 288 | + parser.add_argument("--percent", nargs="+", type=int, |
| 289 | + default=[100, 0, 100, 0, 100, 0], |
| 290 | + help="Six numbers. They are " |
| 291 | + "the percentage of weight on GPU, " |
| 292 | + "the percentage of weight on CPU, " |
| 293 | + "the percentage of attention cache on GPU, " |
| 294 | + "the percentage of attention cache on CPU, " |
| 295 | + "the percentage of activations on GPU, " |
| 296 | + "the percentage of activations on CPU") |
| 297 | + parser.add_argument("--sep-layer", type=str2bool, nargs='?', |
| 298 | + const=True, default=True) |
| 299 | + parser.add_argument("--pin-weight", type=str2bool, nargs="?", |
| 300 | + const=True, default=True) |
| 301 | + parser.add_argument("--cpu-cache-compute", action="store_true") |
| 302 | + parser.add_argument("--attn-sparsity", type=float, default=1.0) |
| 303 | + parser.add_argument("--compress-weight", action="store_true", |
| 304 | + help="Whether to compress weight.") |
| 305 | + parser.add_argument("--compress-cache", action="store_true", |
| 306 | + help="Whether to compress cache.") |
| 307 | + parser.add_argument("--log-file", type=str, default="auto") |
| 308 | + parser.add_argument("--no-log", action="store_true") |
| 309 | + parser.add_argument("--verbose", type=int, default=2) |
| 310 | + parser.add_argument("--overlap", type=str2bool, nargs='?', |
| 311 | + const=True, default=True) |
| 312 | + |
| 313 | + |
| 314 | +if __name__ == "__main__": |
| 315 | + parser = argparse.ArgumentParser() |
| 316 | + add_parser_arguments(parser) |
| 317 | + args = parser.parse_args() |
| 318 | + |
| 319 | + assert len(args.percent) == 6 |
| 320 | + |
| 321 | + run_flexgen(args) |