
Commit 202029d: "First commit" (0 parents)

File tree

3 files changed: 429 additions, 0 deletions. dataset.py and finetune.py appear below; the third file, the prune module that finetune.py imports, is not shown in this view.

dataset.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import os


def loader(path, batch_size=32, num_workers=4, pin_memory=True):
    # ImageNet channel statistics, as expected by torchvision's pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return data.DataLoader(
        datasets.ImageFolder(path,
                             transforms.Compose([
                                 transforms.Scale(256),            # renamed transforms.Resize in later torchvision
                                 transforms.RandomSizedCrop(224),  # renamed transforms.RandomResizedCrop
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory)


def test_loader(path, batch_size=32, num_workers=4, pin_memory=True):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return data.DataLoader(
        datasets.ImageFolder(path,
                             transforms.Compose([
                                 transforms.Scale(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)
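
Both helpers wrap a torchvision ImageFolder in a standard DataLoader, so a caller only supplies the dataset directory. A minimal usage sketch (the train/ and test/ directory names here are illustrative, though they match the argparse defaults in finetune.py below):

    import dataset

    # ImageFolder expects one subdirectory per class, e.g. train/cats/ and train/dogs/.
    train_loader = dataset.loader("train")
    val_loader = dataset.test_loader("test")

    for images, labels in train_loader:
        # images: (N, 3, 224, 224) float tensor; labels: (N,) class indices
        break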

finetune.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dataset
from prune import *
import argparse
from operator import itemgetter
from heapq import nsmallest
import time


class Model(torch.nn.Module):
    def __init__(self, vgg_model):
        super(Model, self).__init__()
        # Reuse the pretrained VGG16 feature extractor, frozen.
        self.features = vgg_model.features
        for param in self.features.parameters():
            param.requires_grad = False

        # Fresh classifier head for a binary classification task.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def train_batch(model, optimizer, criterion, batch, label):
    batch = batch.cuda()
    label = label.cuda()
    output = model(Variable(batch))
    loss = criterion(output, Variable(label))
    model.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


def train_epoch(data_loader, model, optimizer):
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    for i, (batch, label) in enumerate(data_loader):
        loss = train_batch(model, optimizer, criterion, batch, label)


def test(data_loader, model):
    model.eval()
    correct = 0
    total = 0
    for i, (batch, label) in enumerate(data_loader):
        batch = batch.cuda()
        output = model(Variable(batch))
        pred = output.data.max(1)[1]
        correct += pred.cpu().eq(label).sum()
        total += label.size(0)

    print("Accuracy :", float(correct) / total)


def train(train_path, test_path):
    model = Model(torchvision.models.vgg16(pretrained=True))
    model = model.cuda()
    data_loader = dataset.loader(train_path, pin_memory=True)
    test_data_loader = dataset.test_loader(test_path, pin_memory=True)
    # Only the classifier head is optimized; the features stay frozen.
    optimizer = optim.SGD(model.classifier.parameters(), lr=0.0001, momentum=0.9)

    test(test_data_loader, model)
    print("Starting..")
    epochs = 10
    for i in range(epochs):
        train_epoch(data_loader, model, optimizer)
        test(test_data_loader, model)
    torch.save(model, "model")


def random_prune(model):
    # Relies on the module-level args parsed in __main__.
    test_data_loader = dataset.test_loader(args.test_path, pin_memory=True, batch_size=64)
    test(test_data_loader, model)
    model = model.cpu()
    model.eval()

    l = len(model.features._modules.items())
    print("Total number of layers", l)
    items = list(enumerate(model.features._modules.items()))
    items = items[::-5]  # every 5th layer, deepest first
    for i, (name, module) in items:
        if i >= l - 5:  # skip the last 5 layers of the feature extractor
            continue
        if isinstance(module, torch.nn.modules.conv.Conv2d):
            filters = model.features._modules[name].weight.size(0)
            t0 = time.time()
            # Repeatedly prune filter 0: indices shift down after each
            # removal, so this drops half the filters in the layer.
            for _ in range(filters // 2):
                model.features = prune_conv_layer(model.features, i, 0)

    print("After pruning..")
    model = model.cuda()
    test(test_data_loader, model)


class ImportanceExtractor:
    def __init__(self, model):
        self.model = model
        num_convs = 0  # counted but currently unused
        for name, module in self.model.features._modules.items():
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                num_convs = num_convs + 1

        self.importance_values = {}

    def __call__(self, x):
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}
        activation_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                x.register_hook(self.save_gradient)
                self.activations += [x]
                self.activation_to_layer[activation_index] = layer
                activation_index += 1

        feature_output = x.view(x.size(0), -1)
        final_output = self.model.classifier(feature_output)
        return final_output

    def save_gradient(self, grad):
        # Hooks fire deepest layer first during backward, so index from the back.
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        # Per-filter Taylor estimate: activation * gradient summed over the batch
        # and spatial dimensions, then averaged.  The indexing relies on the old
        # PyTorch default of keepdim=True for reductions; in current PyTorch the
        # same quantity is (activation * grad).sum(dim=(0, 2, 3)).
        values = torch.sum((activation * grad), dim=0).sum(dim=2).sum(dim=3)[0, :, 0, 0].data
        values = values / (activation.size(0) * activation.size(2) * activation.size(3))

        if activation_index not in self.importance_values:
            self.importance_values[activation_index] = torch.FloatTensor(activation.size(1)).zero_().cuda()

        self.importance_values[activation_index] += values

        self.grad_index += 1

    def find_k_minimum(self, num):
        data = []
        for i in sorted(self.importance_values.keys())[:-1]:  # TBD: skips the deepest conv layer
            for j in range(self.importance_values[i].size(0)):
                data.append((self.activation_to_layer[i], j, self.importance_values[i][j]))

        k_minimum = nsmallest(num, data, itemgetter(2))
        return k_minimum


def get_filters_to_prune(train_data_loader, model, batch_size=64, num_batches=15):
    model.eval()
    extractor = ImportanceExtractor(model)
    print("In get_filters_to_prune")
    t0 = time.time()

    criterion = torch.nn.CrossEntropyLoss()

    # for param in model.features.parameters():
    #     param.requires_grad = True

    for i, (batch, label) in enumerate(train_data_loader):
        # if i >= num_batches:
        #     break

        batch = batch.cuda()
        label = label.cuda()

        output = extractor(Variable(batch))
        loss = criterion(output, Variable(label))
        model.zero_grad()
        loss.backward()

    # Layer-wise L2 normalization of the absolute importance values.
    for i in extractor.importance_values:
        v = torch.abs(extractor.importance_values[i])
        v = v / np.sqrt(torch.sum(v * v))
        extractor.importance_values[i] = v.cpu()

    # for param in model.features.parameters():
    #     param.requires_grad = False

    filters_to_prune = extractor.find_k_minimum(256)
    # After each of the k filters is pruned, the indices of the remaining
    # filters in the same layer shift down, since the model is now smaller.

    filters_to_prune_per_layer = {}
    for (l, f, _) in filters_to_prune:
        if l not in filters_to_prune_per_layer:
            filters_to_prune_per_layer[l] = []
        filters_to_prune_per_layer[l].append(f)

    for l in filters_to_prune_per_layer:
        filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l])
        for i in range(len(filters_to_prune_per_layer[l])):
            # Compensate for the index shift caused by earlier removals.
            filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i

    del filters_to_prune
    filters_to_prune = []
    for l in filters_to_prune_per_layer:
        for i in filters_to_prune_per_layer[l]:
            filters_to_prune.append((l, i))

    del extractor

    return filters_to_prune


def num_filters(model):
    filters = 0
    for name, module in model.features._modules.items():
        if isinstance(module, torch.nn.modules.conv.Conv2d):
            filters = filters + module.out_channels
    return filters


def taylor_prune(train_path, test_path, model):
    train_data_loader = dataset.loader(train_path, pin_memory=True, batch_size=16)
    test_data_loader = dataset.test_loader(test_path, pin_memory=True)
    model = model.cuda()
    model.eval()
    test(test_data_loader, model)
    for param in model.features.parameters():
        param.requires_grad = True

    for iteration in range(8):
        filters_to_prune = get_filters_to_prune(train_data_loader, model)
        model = model.cpu()
        for layer_index, filter_index in filters_to_prune:
            model.features = prune_conv_layer(model.features, layer_index, filter_index)

        model = model.cuda()
        print("After pruning", iteration, "Number of filters left", num_filters(model))
        test(test_data_loader, model)

        # Fine-tune to recover accuracy before the next pruning iteration.
        model.train()
        for param in model.features.parameters():
            param.requires_grad = True

        optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
        epochs = 5
        for i in range(epochs):
            train_epoch(train_data_loader, model, optimizer)

        print("Finished fine tuning; starting next iteration")
        test(test_data_loader, model)
        torch.save(model, "model_prunned")

    # for param in model.features.parameters():
    #     param.requires_grad = False

    # train_data_loader = dataset.loader(train_path, pin_memory=True, batch_size=32)
    # optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.95)
    # print("Finished pruning. Now training more..")
    # for i in range(25):
    #     train_epoch(train_data_loader, model, optimizer)
    #     test(test_data_loader, model)
    # torch.save(model, "model_prunned")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", dest="train", action="store_true")
    parser.add_argument("--random_prune", dest="random_prune", action="store_true")
    parser.add_argument("--train_path", type=str, default="train")
    parser.add_argument("--test_path", type=str, default="test")
    parser.set_defaults(train=False)
    parser.set_defaults(random_prune=False)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()

    if args.train:
        train(args.train_path, args.test_path)
    elif args.random_prune:
        model = torch.load("model")
        model.eval()
        model = model.cuda()
        random_prune(model)
    else:
        model = torch.load("model")
        model = model.cuda()
        taylor_prune(args.train_path, args.test_path, model)
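
The __main__ block dispatches on the flags defined in get_args(). A typical sequence of runs, assuming ImageFolder-style train/ and test/ directories (the argparse defaults):

    # 1. Fine-tune the VGG16 classifier head; saves the result to ./model
    python finetune.py --train --train_path train --test_path test

    # 2. Iterative Taylor-criterion pruning of the saved model (the default branch)
    python finetune.py --train_path train --test_path test

    # 3. Baseline that prunes half the filters in selected conv layers
    python finetune.py --random_prune --test_path test

Each pruning iteration saves the current model to ./model_prunned, so intermediate results survive an interrupted run.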
