This repository was archived by the owner on Dec 1, 2024. It is now read-only.

Commit 8b89e02

support qwen models

1 parent 8e2cc94 commit 8b89e02

File tree

3 files changed: +641 −0 lines changed

flexgen/flex_qwen.py

Lines changed: 321 additions & 0 deletions

@@ -0,0 +1,321 @@
"""
Usage:
python3 -m flexgen.flex_qwen --model Qwen/Qwen1.5-0.5B-Chat --gpu-batch-size 32 --percent 100 0 100 0 100 0
"""
import os
import torch
import argparse
from typing import Union
from transformers import AutoTokenizer
from flexgen.compression import CompressionConfig
from flexgen.qwen_config import QwenConfig, get_qwen_config, download_qwen_weights
from flexgen.flex_llama import LlamaInputEmbed, LlamaOutputEmbed, LlamaMLP
from flexgen.pytorch_backend import QwenTorchDevice, TorchDisk, TorchMixedDevice, fix_recursive_import
from flexgen.flex_opt import (Policy, init_weight_list, SelfAttention, TransformerLayer,
                              OptLM, get_filename, get_test_inputs)
from flexgen.timer import timers
from flexgen.utils import (ExecutionEnv, GB, ValueHolder,
                           array_1d, array_2d, str2bool, project_decode_latency, write_benchmark_log)

fix_recursive_import()

DUMMY_WEIGHT = "_DUMMY_"  # Use dummy weights for benchmark purposes
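

# QwenSelfAttention subclasses FlexGen's SelfAttention for its offloading and scheduling
# hooks, but follows the Qwen layout: the q/k/v projections carry bias terms, the weights
# use the input_layernorm / self_attn.{q,k,v,o}_proj naming, and rope_theta plus
# rms_norm_eps from the model config are forwarded to the Qwen attention kernels.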
class QwenSelfAttention(SelfAttention):
    def __init__(self, config, env, policy, layer_id):
        super().__init__(config, env, policy, layer_id)

    def init_weight(self, weight_home, path):
        h, n_head, n_kv_head, dtype = (self.config.input_dim, self.config.n_head, self.config.num_key_value_heads, self.config.dtype)
        head_dim = h // n_head
        path = os.path.join(os.path.join(path, f"layers.{self.layer_id}."))
        weight_specs = [
            # w_ln
            ((h,), dtype, path + "input_layernorm.weight"),
            # w_q
            ((h, n_head*head_dim), dtype, path + "self_attn.q_proj.weight"),
            # b_q
            ((n_head*head_dim,), dtype, path + "self_attn.q_proj.bias"),
            # w_k
            ((n_kv_head*head_dim, h), dtype, path + "self_attn.k_proj.weight"),
            # b_k
            ((h,), dtype, path + "self_attn.k_proj.bias"),
            # w_v
            ((n_kv_head*head_dim, h), dtype, path + "self_attn.v_proj.weight"),
            # b_v
            ((h,), dtype, path + "self_attn.v_proj.bias"),
            # w_o
            ((n_head*head_dim, h), dtype, path + "self_attn.o_proj.weight"),
        ]
        weights = init_weight_list(weight_specs, self.policy, self.env)
        weight_home.store(weights)

    def load_weight(self, weight_home, weight_read_buf, k):
        w_ln, w_q, b_q, w_k, b_k, w_v, b_v, w_o = weight_home.val
        if k == 0:
            dst1 = self.weight_load_dst
            dst2 = self.compute
            weight_read_buf.store((
                w_ln.smart_copy(dst2),
                w_q.smart_copy(dst1), b_q.smart_copy(dst2),
                w_k.smart_copy(dst1), b_k.smart_copy(dst2),
                w_v.smart_copy(dst1), b_v.smart_copy(dst2),
                w_o.smart_copy(dst1)))

    def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask,
                cache_write_buf, i, k):
        n_head = self.config.n_head
        n_kv_head = self.config.num_key_value_heads

        donate = [False] * 12
        h, donate[0] = hidden.val, True

        if k == self.policy.num_gpu_batches - 1:
            # Clear the weight_read_buf if it is the last gpu batch
            ((w_ln, donate[2]), (w_q, donate[3]), (b_q, donate[4]), (w_k, donate[5]), (b_k, donate[6]),
             (w_v, donate[7]), (b_v, donate[8]), (w_o, donate[9])) = weight_read_buf.pop()
        else:
            ((w_ln, _), (w_q, _), (b_q, _), (w_k, _), (b_k, _), (w_v, _), (b_v, _),
             (w_o, _)) = weight_read_buf.val
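
        # Prefill (i == 0) runs attention over the whole prompt and writes a fresh k/v
        # cache; decoding pops the existing cache and attends only from the last position.
        # Position ids are rebuilt from the cumulative attention mask so padded positions
        # do not advance the index.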
        if i == 0:  # prefill
            mask, donate[1] = attention_mask.val.smart_copy(self.compute)
            position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1
            h, new_k_cache, new_v_cache = self.compute.qwen_mha(h, position_ids, mask, w_ln,
                w_q, b_q, w_k, b_k, w_v, b_v, w_o, n_head, n_kv_head, donate, self.config.rms_norm_eps, self.config.rope_theta,
                self.policy.compress_cache, self.policy.comp_cache_config)
            cache_write_buf.store((new_k_cache, new_v_cache))
        else:  # decoding
            mask, donate[1] = attention_mask.val.smart_copy(self.attention_compute)
            (k_cache, donate[10]), (v_cache, donate[11]) = cache_read_buf.pop()
            position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1
            position_ids = position_ids[:, -h.shape[1]].unsqueeze(1)
            h, new_k_cache, new_v_cache = self.compute.qwen_mha_gen(h, position_ids, mask, w_ln,
                w_q, b_q, w_k, b_k, w_v, b_v, w_o, self.config.rms_norm_eps, self.config.rope_theta, n_head, n_kv_head,
                k_cache, v_cache, donate, self.policy.attn_sparsity,
                self.policy.compress_cache, self.policy.comp_cache_config)
            cache_write_buf.store((new_k_cache, new_v_cache))

        hidden.val = h
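

# QwenTransformerLayer fuses the attention and MLP sub-layers into one schedulable unit;
# it is used when policy.sep_layer is disabled.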
class QwenTransformerLayer(TransformerLayer):
    def __init__(self, config, env, policy, i):
        self.attention = QwenSelfAttention(config, env, policy, i)
        self.mlp = LlamaMLP(config, env, policy, i)
        self.policy = policy
        self.compute = self.attention.compute
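

# QwenLM reuses OptLM's generation loop and offloading schedule; only the layer stack is
# swapped out. With policy.sep_layer, attention and MLP are registered as separate layers;
# otherwise each block becomes a single QwenTransformerLayer.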
class QwenLM(OptLM):
    def __init__(self,
                 config: Union[str, QwenConfig],
                 env: ExecutionEnv,
                 path: str,
                 policy: Policy):
        if isinstance(config, str):
            config = get_qwen_config(config)
        self.config = config
        self.env = env
        self.path = path
        self.policy = policy
        self.num_gpu_batches = policy.num_gpu_batches

        layers = []
        layers.append(LlamaInputEmbed(self.config, self.env, self.policy))
        for i in range(self.config.num_hidden_layers):
            if policy.sep_layer:
                layers.append(QwenSelfAttention(self.config, self.env, self.policy, i))
                layers.append(LlamaMLP(self.config, self.env, self.policy, i))
            else:
                layers.append(QwenTransformerLayer(self.config, self.env, self.policy, i))
        layers.append(LlamaOutputEmbed(self.config, self.env, self.policy))
        self.layers = layers
        self.num_layers = len(layers)

        if self.policy.act_gpu_percent == 100:
            self.act_home = self.env.gpu
        elif self.policy.act_cpu_percent == 100:
            self.act_home = self.env.cpu
        elif self.policy.act_disk_percent == 100:
            self.act_home = self.env.disk
        else:
            raise NotImplementedError()

        # CUDA streams
        self.load_weight_stream = torch.cuda.Stream()
        self.load_cache_stream = torch.cuda.Stream()
        self.store_cache_stream = torch.cuda.Stream()

        # Intermediate tensors
        # The following buffers store values used
        # for the i-th token, j-th layer, k-th gpu batch.
        num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches

        # cache[j][k]
        self.cache_home = array_2d(num_layers, num_gpu_batches, ValueHolder)
        self.cache_read_buf = array_2d(num_layers, num_gpu_batches, ValueHolder)
        self.cache_write_buf = array_2d(num_layers, num_gpu_batches, ValueHolder)
        # weight[j]
        self.weight_read_buf = array_1d(num_layers, ValueHolder)
        # attention_mask[k]
        self.attention_mask = array_1d(num_gpu_batches, ValueHolder)

        self.task = None
        self.init_all_weights()
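
    # Weights are expected as a per-tensor dump under "<path>/<model-name>-np"; if the
    # embedding weight is missing (and dummy weights are not requested),
    # download_qwen_weights is called to fetch them first.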
    def init_weight(self, j):
        expanded_path = os.path.abspath(os.path.expanduser(
            os.path.join(self.path, f"{self.config.name}-np")))
        check_path = os.path.join(expanded_path, "embed_tokens.weight")
        if not os.path.exists(check_path) and DUMMY_WEIGHT not in check_path:
            download_qwen_weights(self.config.name, self.path)

        self.layers[j].init_weight(self.weight_home[j], expanded_path)
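

# Entry point: builds the tokenizer, devices, and offloading policy, runs a short warmup
# generation followed by the benchmarked run, then reports latency/throughput and writes
# a benchmark log.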
def run_flexgen(args):
    print(f"<run_flexgen>: args.model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    num_prompts = args.num_gpu_batches * args.gpu_batch_size
    prompt_len, gen_len, cut_gen_len = args.prompt_len, args.gen_len, args.cut_gen_len

    # Task and policy
    warmup_inputs = get_test_inputs(32, num_prompts, tokenizer)
    inputs = get_test_inputs(prompt_len, num_prompts, tokenizer)

    gpu = QwenTorchDevice("cuda:0")
    cpu = QwenTorchDevice("cpu")
    disk = TorchDisk(args.offload_dir)
    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
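
    # The six --percent values map pairwise to the GPU/CPU split of weights, the attention
    # (k/v) cache, and activations; any remainder falls back to disk, the third device in
    # the ExecutionEnv above.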
    policy = Policy(args.gpu_batch_size, args.num_gpu_batches,
                    args.percent[0], args.percent[1],
                    args.percent[2], args.percent[3],
                    args.percent[4], args.percent[5],
                    args.overlap, args.sep_layer, args.pin_weight,
                    args.cpu_cache_compute, args.attn_sparsity,
                    args.compress_weight,
                    CompressionConfig(num_bits=4, group_size=64,
                                      group_dim=0, symmetric=False),
                    args.compress_cache,
                    CompressionConfig(num_bits=4, group_size=64,
                                      group_dim=2, symmetric=False))
    assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented"

    qwen_config = get_qwen_config(args.model, pad_token_id=tokenizer.eos_token_id)
    cache_size = qwen_config.cache_bytes(num_prompts, prompt_len + gen_len)
    hidden_size = qwen_config.hidden_bytes(num_prompts, prompt_len + gen_len)
    print(f"model size: {qwen_config.model_bytes()/GB:.3f} GB, "
          f"cache size: {cache_size/GB:.3f} GB, "
          f"hidden size (prefill): {hidden_size/GB:.3f} GB")

    print("init weight...")
    model = QwenLM(qwen_config, env, args.path, policy)

    try:
        print("warmup - generate")
        output_ids = model.generate(
            warmup_inputs, max_new_tokens=1, verbose=args.verbose)

        print("benchmark - generate")
        timers("generate").reset()
        output_ids = model.generate(
            inputs, max_new_tokens=args.gen_len,
            debug_mode=args.debug_mode, cut_gen_len=cut_gen_len, verbose=args.verbose)
        costs = timers("generate").costs
    finally:
        env.close_copy_threads()

    # Log output
    prefill_latency = costs[0]
    prefill_throughput = num_prompts * prompt_len / prefill_latency
    if cut_gen_len:  # project latency of cut_gen_len to gen_len
        decode_latency = project_decode_latency(costs, prompt_len, gen_len)
    else:
        decode_latency = sum(costs[1:])
    decode_throughput = num_prompts * (gen_len - 1) / max(decode_latency, 1e-10)
    num_generated_tokens = num_prompts * gen_len
    total_latency = prefill_latency + decode_latency
    total_throughput = num_generated_tokens / total_latency
    _, gpu_peak_mem = gpu.mem_stats()
    _, cpu_peak_mem = cpu.mem_stats()

    if DUMMY_WEIGHT not in args.path:
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        show_str = "Outputs:\n" + 70 * '-' + "\n"
        for i in [0, len(outputs)-1]:
            show_str += f"{i}: {outputs[i]}\n"
            show_str += "-" * 70 + "\n"
        if args.verbose >= 2:
            print(show_str)

    gpu.print_stats()
    cpu.print_stats()
    projected = bool(args.debug_mode or cut_gen_len)

    if args.log_file == "auto":
        filename = get_filename(args) + ".log"
    else:
        filename = args.log_file

    log_str = write_benchmark_log(filename,
        qwen_config.model_bytes(), cache_size, hidden_size,
        gpu_peak_mem, projected, prefill_latency, prefill_throughput,
        decode_latency, decode_throughput, total_latency, total_throughput)
    if args.verbose >= 1:
        print(log_str)
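

# The flags mirror the other FlexGen entry points; the defaults for --model and --path
# point at Qwen weights.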
def add_parser_arguments(parser):
    parser.add_argument("--model", type=str, default="Qwen/Qwen1.5-7B-Chat",
        help="The model name.")
    parser.add_argument("--path", type=str, default="~/qwen_weights",
        help="The path to the model weights. If there are no cached weights, "
             "FlexGen will automatically download them from HuggingFace.")
    parser.add_argument("--offload-dir", type=str, default="~/flexgen_offload_dir",
        help="The directory to offload tensors. ")
    parser.add_argument("--prompt-len", type=int, default=512)
    parser.add_argument("--gen-len", type=int, default=32)
    parser.add_argument("--cut-gen-len", type=int,
        help="Cut generation length for fast debugging.")
    parser.add_argument("--debug-mode", type=str,
        choices=["fewer_batch", "breakdown"])
    parser.add_argument("--gpu-batch-size", type=int, default=4)
    parser.add_argument("--num-gpu-batches", type=int, default=1)
    parser.add_argument("--percent", nargs="+", type=int,
        default=[100, 0, 100, 0, 100, 0],
        help="Six numbers. They are "
             "the percentage of weight on GPU, "
             "the percentage of weight on CPU, "
             "the percentage of attention cache on GPU, "
             "the percentage of attention cache on CPU, "
             "the percentage of activations on GPU, "
             "the percentage of activations on CPU")
    parser.add_argument("--sep-layer", type=str2bool, nargs='?',
        const=True, default=True)
    parser.add_argument("--pin-weight", type=str2bool, nargs="?",
        const=True, default=True)
    parser.add_argument("--cpu-cache-compute", action="store_true")
    parser.add_argument("--attn-sparsity", type=float, default=1.0)
    parser.add_argument("--compress-weight", action="store_true",
        help="Whether to compress weight.")
    parser.add_argument("--compress-cache", action="store_true",
        help="Whether to compress cache.")
    parser.add_argument("--log-file", type=str, default="auto")
    parser.add_argument("--no-log", action="store_true")
    parser.add_argument("--verbose", type=int, default=2)
    parser.add_argument("--overlap", type=str2bool, nargs='?',
        const=True, default=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_parser_arguments(parser)
    args = parser.parse_args()

    assert len(args.percent) == 6

    run_flexgen(args)
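
For reference, a sketch of an invocation that splits the weights between GPU and CPU (the model choice and percentages here are illustrative, not values from this commit; all flags are defined in add_parser_arguments above):

python3 -m flexgen.flex_qwen --model Qwen/Qwen1.5-7B-Chat --gpu-batch-size 4 --percent 50 50 100 0 100 0 --cpu-cache-compute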
