import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dataset
from prune import *
import argparse
from operator import itemgetter
from heapq import nsmallest
import time

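# Fine-tunes a pre-trained VGG16 on a two-class dataset (the 2-way head below
# suggests a binary task such as dogs vs. cats), then prunes convolutional
# filters, either unguided or ranked by a first-order Taylor-expansion
# importance criterion. The `dataset` and `prune` modules (data loaders and
# prune_conv_layer) are project-local and assumed to live alongside this file.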
class Model(torch.nn.Module):
    def __init__(self, vgg_model):
        super(Model, self).__init__()
        self.features = vgg_model.features
        for param in self.features.parameters():
            param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

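# One SGD step on a single batch. Gradients flow wherever requires_grad is
# True; which parameters actually update depends on the optimizer passed in.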
def train_batch(model, optimizer, criterion, batch, label):
    batch = batch.cuda()
    label = label.cuda()
    output = model(Variable(batch))
    loss = criterion(output, Variable(label))
    model.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

def train_epoch(data_loader, model, optimizer):
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    for i, (batch, label) in enumerate(data_loader):
        loss = train_batch(model, optimizer, criterion, batch, label)

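# Top-1 accuracy over the test loader; labels stay on the CPU, so predictions
# are moved back with .cpu() before comparison.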
def test(data_loader, model):
    model.eval()
    correct = 0
    total = 0
    for i, (batch, label) in enumerate(data_loader):
        batch = batch.cuda()
        output = model(Variable(batch))
        pred = output.data.max(1)[1]
        correct += pred.cpu().eq(label).sum()
        total += label.size(0)

    print("Accuracy:", float(correct) / total)

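# Transfer learning: start from ImageNet-pretrained VGG16, freeze the
# convolutional features (done in Model.__init__), and train only the new
# two-class classifier head.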
def train(train_path, test_path):
    model = Model(torchvision.models.vgg16(pretrained=True))
    model = model.cuda()
    data_loader = dataset.loader(train_path, pin_memory=True)
    test_data_loader = dataset.test_loader(test_path, pin_memory=True)
    optimizer = optim.SGD(model.classifier.parameters(), lr=0.0001, momentum=0.9)

    test(test_data_loader, model)
    print("Starting..")
    epochs = 10
    for i in range(epochs):
        train_epoch(data_loader, model, optimizer)
        test(test_data_loader, model)
    torch.save(model, "model")

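# Baseline experiment: remove half the filters from a handful of conv layers
# with no importance ranking, to measure how much accuracy unguided pruning
# costs compared to the Taylor-criterion pruning below.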
def random_prune(model):
    # Note: uses the global `args` parsed in __main__.
    test_data_loader = dataset.test_loader(args.test_path, pin_memory=True, batch_size=64)
    test(test_data_loader, model)
    model = model.cpu()
    model.eval()

    l = len(model.features._modules)
    print("Total number of layers", l)
    items = list(enumerate(model.features._modules.items()))
    # Walk backwards through every fifth module; pruning from the deep end
    # keeps the indices of shallower layers valid.
    items = items[::-5]
    for i, (name, module) in items:
        # Leave the last five layers untouched.
        if i >= l - 5:
            continue
        if isinstance(module, torch.nn.modules.conv.Conv2d):
            filters = model.features._modules[name].weight.size(0)
            t0 = time.time()
            # Remove half the filters; indices shift after each removal, so
            # pruning index 0 repeatedly removes the current first filter.
            for _ in range(filters // 2):
                model.features = prune_conv_layer(model.features, i, 0)

    print("After pruning..")
    model = model.cuda()
    test(test_data_loader, model)

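# Estimates per-filter importance with a first-order Taylor criterion, as in
# Molchanov et al., "Pruning Convolutional Neural Networks for Resource
# Efficient Inference": the importance of a filter's activation z is
# |E[z * dC/dz]|, accumulated over batches and averaged over the batch and
# spatial dimensions. Activations are recorded during the forward pass and
# their gradients are captured with register_hook during backward.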
class ImportanceExtractor:
    def __init__(self, model):
        self.model = model
        num_convs = 0
        for name, module in self.model.features._modules.items():
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                num_convs = num_convs + 1

        # Accumulated importance per conv layer, kept across batches.
        self.importance_values = {}

    def __call__(self, x):
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}
        activation_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                # Capture this activation's gradient during backward.
                x.register_hook(self.save_gradient)
                self.activations += [x]
                self.activation_to_layer[activation_index] = layer
                activation_index += 1

        feature_output = x.view(x.size(0), -1)
        final_output = self.model.classifier(feature_output)
        return final_output

    def save_gradient(self, grad):
        # Hooks fire in reverse order of registration, so the first gradient
        # to arrive belongs to the last recorded activation.
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        # Taylor criterion: mean of activation * gradient over the batch and
        # spatial dimensions, giving one value per filter (channel).
        values = torch.sum(activation * grad, dim=(0, 2, 3)).data
        values = values / (activation.size(0) * activation.size(2) * activation.size(3))

        if activation_index not in self.importance_values:
            self.importance_values[activation_index] = torch.FloatTensor(activation.size(1)).zero_().cuda()

        self.importance_values[activation_index] += values

        self.grad_index += 1

    def find_k_minimum(self, num):
        data = []
        # Rank filters across all layers; the last conv layer is skipped (TBD).
        for i in sorted(self.importance_values.keys())[:-1]:
            for j in range(self.importance_values[i].size(0)):
                data.append((self.activation_to_layer[i], j, self.importance_values[i][j]))

        k_minimum = nsmallest(num, data, itemgetter(2))
        return k_minimum

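# One ranking pass: accumulate Taylor importance over the training set,
# L2-normalize per layer, then pick the 256 globally least important filters.
# The returned (layer, filter) pairs have filter indices adjusted for the
# shrinking that happens as earlier filters in the same layer are removed.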
def get_filters_to_prune(train_data_loader, model, batch_size=64, num_batches=15):
    model.eval()
    extractor = ImportanceExtractor(model)
    print("In get_filters_to_prune")
    t0 = time.time()

    criterion = torch.nn.CrossEntropyLoss()

    # for param in model.features.parameters():
    #     param.requires_grad = True

    for i, (batch, label) in enumerate(train_data_loader):
        # if i >= num_batches:
        #     break

        batch = batch.cuda()
        label = label.cuda()

        output = extractor(Variable(batch))
        loss = criterion(output, Variable(label))
        model.zero_grad()
        loss.backward()

    # Layer-wise L2 normalization of the accumulated importance values.
    for i in extractor.importance_values:
        v = torch.abs(extractor.importance_values[i])
        v = v / torch.sqrt(torch.sum(v * v))
        extractor.importance_values[i] = v.cpu()

    # for param in model.features.parameters():
    #     param.requires_grad = False

    filters_to_prune = extractor.find_k_minimum(256)
    # After each of the k filters is pruned, the filter indices of the
    # remaining filters change, since the layer gets smaller.

    filters_to_prune_per_layer = {}
    for (l, f, _) in filters_to_prune:
        if l not in filters_to_prune_per_layer:
            filters_to_prune_per_layer[l] = []
        filters_to_prune_per_layer[l].append(f)

    for l in filters_to_prune_per_layer:
        filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l])
        for i in range(len(filters_to_prune_per_layer[l])):
            filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i

    del filters_to_prune
    filters_to_prune = []
    for l in filters_to_prune_per_layer:
        for i in filters_to_prune_per_layer[l]:
            filters_to_prune.append((l, i))

    del extractor

    return filters_to_prune

def num_filters(model):
    filters = 0
    for name, module in model.features._modules.items():
        if isinstance(module, torch.nn.modules.conv.Conv2d):
            filters = filters + module.out_channels
    return filters

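# Iterative pruning: in each of 8 iterations, rank and remove the 256 least
# important filters, then fine-tune the whole network for a few epochs to
# recover accuracy before the next pruning pass.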
def taylor_prune(train_path, test_path, model):
    train_data_loader = dataset.loader(train_path, pin_memory=True, batch_size=16)
    test_data_loader = dataset.test_loader(test_path, pin_memory=True)
    model = model.cuda()
    model.eval()
    test(test_data_loader, model)
    # Gradients w.r.t. the conv activations are needed for the ranking.
    for param in model.features.parameters():
        param.requires_grad = True

    for iteration in range(8):
        filters_to_prune = get_filters_to_prune(train_data_loader, model)
        model = model.cpu()
        for layer_index, filter_index in filters_to_prune:
            model.features = prune_conv_layer(model.features, layer_index, filter_index)

        model = model.cuda()
        print("After pruning", iteration, "Number of filters left", num_filters(model))
        test(test_data_loader, model)

        model.train()
        for param in model.features.parameters():
            param.requires_grad = True

        optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
        epochs = 5
        for i in range(epochs):
            train_epoch(train_data_loader, model, optimizer)

        print("Finished fine tuning, next iteration")
        test(test_data_loader, model)
        torch.save(model, "model_prunned")

    # for param in model.features.parameters():
    #     param.requires_grad = False

    # train_data_loader = dataset.loader(train_path, pin_memory=True, batch_size=32)
    # optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.95)
    # print("Finished pruning. Now training more..")
    # for i in range(25):
    #     train_epoch(train_data_loader, model, optimizer)
    # test(test_data_loader, model)
    # torch.save(model, "model_prunned")

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", dest="train", action="store_true")
    parser.add_argument("--random_prune", dest="random_prune", action="store_true")
    parser.add_argument("--train_path", type=str, default="train")
    parser.add_argument("--test_path", type=str, default="test")
    parser.set_defaults(train=False)
    parser.set_defaults(random_prune=False)
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_args()

    if args.train:
        train(args.train_path, args.test_path)
    elif args.random_prune:
        model = torch.load("model")
        model.eval()
        model = model.cuda()
        random_prune(model)
    else:
        model = torch.load("model")
        model = model.cuda()
        taylor_prune(args.train_path, args.test_path, model)
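
# Example usage (assumptions: this file is saved as finetune.py, and ./train
# and ./test are image folders in the layout dataset.loader expects):
#   python finetune.py --train         # fine-tune VGG16 and save it as "model"
#   python finetune.py --random_prune  # baseline pruning of the saved model
#   python finetune.py                 # Taylor-criterion pruning + fine-tuning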