From e2a60471ac4037191fb8e7fc7804c913972f8580 Mon Sep 17 00:00:00 2001
From: Jinen Setpal
Date: Wed, 2 Sep 2020 03:57:58 +0530
Subject: [PATCH] compatibility, dynamic arguments, PEP 8 formatting

new run sequence: python visualize.py <input_image> <output_path>
---
 visualize.py | 54 +++++++++++++++++++++++++++++++-----------------------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/visualize.py b/visualize.py
index 5133971..ceea34f 100644
--- a/visualize.py
+++ b/visualize.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import os
+import sys
 
 from torch.autograd import Variable
 import transforms as transforms
@@ -23,12 +24,14 @@
     transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
 ])
 
+
 def rgb2gray(rgb):
-    return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
+    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
+
 
-raw_img = io.imread('images/1.jpg')
+raw_img = io.imread(sys.argv[1])
 gray = rgb2gray(raw_img)
-gray = resize(gray, (48,48), mode='symmetric').astype(np.uint8)
+gray = resize(gray, (48, 48), mode='symmetric').astype(np.uint8)
 
 img = gray[:, :, np.newaxis]
 
@@ -39,16 +42,21 @@ def rgb2gray(rgb):
 class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
 
 net = VGG('VGG19')
-checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7'))
+if torch.cuda.is_available():
+    checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7'))
+else:
+    checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7'), map_location=torch.device('cpu'))
 net.load_state_dict(checkpoint['net'])
-net.cuda()
 net.eval()
 
 ncrops, c, h, w = np.shape(inputs)
 inputs = inputs.view(-1, c, h, w)
-inputs = inputs.cuda()
-inputs = Variable(inputs, volatile=True)
-outputs = net(inputs)
+if torch.cuda.is_available():
+    net.cuda()
+    inputs = inputs.cuda()
+with torch.no_grad():
+    inputs = Variable(inputs)
+    outputs = net(inputs)
 
 outputs_avg = outputs.view(ncrops, -1).mean(0)  # avg over crops
 
@@ -56,29 +64,28 @@ def rgb2gray(rgb):
 
 score = F.softmax(outputs_avg)
 _, predicted = torch.max(outputs_avg.data, 0)
 
-plt.rcParams['figure.figsize'] = (13.5,5.5)
-axes=plt.subplot(1, 3, 1)
+plt.rcParams['figure.figsize'] = (13.5, 5.5)
+axes = plt.subplot(1, 3, 1)
 plt.imshow(raw_img)
 plt.xlabel('Input Image', fontsize=16)
 axes.set_xticks([])
 axes.set_yticks([])
 plt.tight_layout()
-
 plt.subplots_adjust(left=0.05, bottom=0.2, right=0.95, top=0.9, hspace=0.02, wspace=0.3)
 
 plt.subplot(1, 3, 2)
-ind = 0.1+0.6*np.arange(len(class_names))    # the x locations for the groups
-width = 0.4       # the width of the bars: can also be len(x) sequence
-color_list = ['red','orangered','darkorange','limegreen','darkgreen','royalblue','navy']
+ind = 0.1 + 0.6 * np.arange(len(class_names))  # the x locations for the groups
+width = 0.4  # the width of the bars: can also be len(x) sequence
+color_list = ['red', 'orangered', 'darkorange', 'limegreen', 'darkgreen', 'royalblue', 'navy']
 for i in range(len(class_names)):
     plt.bar(ind[i], score.data.cpu().numpy()[i], width, color=color_list[i])
-plt.title("Classification results ",fontsize=20)
-plt.xlabel(" Expression Category ",fontsize=16)
-plt.ylabel(" Classification Score ",fontsize=16)
+plt.title("Classification results ", fontsize=20)
+plt.xlabel(" Expression Category ", fontsize=16)
+plt.ylabel(" Classification Score ", fontsize=16)
 plt.xticks(ind, class_names, rotation=45, fontsize=14)
 
-axes=plt.subplot(1, 3, 3)
+axes = plt.subplot(1, 3, 3)
 emojis_img = io.imread('images/emojis/%s.png' % str(class_names[int(predicted.cpu().numpy())]))
 plt.imshow(emojis_img)
 plt.xlabel('Emoji Expression', fontsize=16)
@@ -87,10 +94,11 @@ def rgb2gray(rgb):
 
 plt.tight_layout()
 # show emojis
-#plt.show()
-plt.savefig(os.path.join('images/results/l.png'))
+# plt.show()
+resdir = '/'.join(sys.argv[2].split('/')[:-1])
+if resdir and not os.path.isdir(resdir):
+    os.makedirs(resdir)
+plt.savefig(sys.argv[2])
 plt.close()
 
-print("The Expression is %s" %str(class_names[int(predicted.cpu().numpy())]))
-
-
+print("The Expression is %s" % str(class_names[int(predicted.cpu().numpy())]))
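
Usage sketch for the new run sequence (the argument order follows sys.argv[1]
and sys.argv[2] in the diff; the paths are illustrative, images/1.jpg being the
input the script previously hard-coded):

    python visualize.py images/1.jpg images/results/1.png

The parent directory of the second argument (images/results here) is created
by the script if it does not already exist.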