Skip to content

Commit 3cca804

Browse files
author
Yongjian Feng
committed
first checkin
0 parents  commit 3cca804

15 files changed

+618
-0
lines changed

Adaline.m

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
classdef Adaline < matlab.graphics.chartcontainer.ChartContainer
    % ADALINE  Animated comparison of gradient-descent variants for a
    % two-feature linear classifier.
    %
    %   chart = Adaline(csvFile) reads a table from csvFile, standardises
    %   its first two columns and plots them in the top axes, split by the
    %   value of the table variable 'cls' (1 or -1). Calling
    %   chart.animate() then runs up to maxSteps updates of batch,
    %   mini-batch and/or stochastic gradient descent (selected by the
    %   show* flags), redrawing the decision boundaries in the top axes and
    %   the per-step cost curves in the bottom axes.
    %
    %   The CSV file must contain at least two feature columns (the first
    %   two are used) and a column named 'cls'.

    properties(Access=public)
        alpha;          % learning rate applied to every weight update
        maxSteps;       % maximum number of animation steps
        theshold;       % (sic) a mode stops once the change in its cost drops below this value;
                        % the misspelled name is kept so existing callers keep working
        batchSize;      % number of samples drawn for each mini-batch step

        showBatch;      % run/draw batch gradient descent
        showMiniBatch;  % run/draw mini-batch gradient descent
        showStochastic; % run/draw stochastic gradient descent

        pauseSec;       % pause between steps; also the GIF frame delay after the first frame
        outfile;        % animation gif; empty means "animate on screen only"
        showHighlight;  % highlight dots used by minibatch/stochastic
    end

    properties(Access=protected)
        TrainingData;   % table read from the CSV plus the scaled columns Var1s/Var2s

        % scatter objects in the top axes
        Class1Scatter;
        Class2Scatter;
        MiniBHighlight;
        StochHighlight;

        % decision-boundary lines in the top axes
        BatchDecision;
        StochDecision;
        MiniBDecision;

        % cost curves in the bottom axes
        CostBatchPlot;
        CostStochPlot;
        CostMiniBPlot;

        TopAxes;
        BotAxes;

        % weights (W is 2x1, W0 is the scalar bias), one pair per mode
        Wbatch;
        W0batch;

        Wstoch;
        W0stoch;

        WminiB;
        W0miniB;
    end

    properties(Access=protected)
        % extent of the scaled training data; used to draw the boundaries
        Xmin;
        Xmax;
        Ymin;
        Ymax;

        % names of the first two table variables, used as axis labels
        Xlabel;
        Ylabel;
        % 1x3 logical [batch stochastic miniBatch]; true once that mode is finished
        DoneGraDes;
    end

    methods(Access=protected)
        function setup(obj)
            % Create the 2x1 layout and every graphics object with NaN
            % placeholder data; real data is filled in later.
            tcl = getLayout(obj);
            tcl.GridSize = [2 1];
            obj.TopAxes = nexttile(tcl, 1);
            obj.BotAxes = nexttile(tcl, 2);

            obj.Class1Scatter = scatter(obj.TopAxes, NaN, NaN, 'c*');
            hold(obj.TopAxes, 'on');
            obj.Class2Scatter = scatter(obj.TopAxes, NaN, NaN, 'mo');
            obj.MiniBHighlight = scatter(obj.TopAxes, NaN, NaN, 'MarkerFaceColor', 'b', 'MarkerFaceAlpha', 0.1, 'MarkerEdgeColor', 'b');
            obj.StochHighlight = scatter(obj.TopAxes, NaN, NaN, 'MarkerFaceColor', 'g', 'MarkerFaceAlpha', 0.1, 'MarkerEdgeColor', 'g');

            obj.BatchDecision = plot(obj.TopAxes, NaN, NaN, 'r-');
            obj.StochDecision = plot(obj.TopAxes, NaN, NaN, 'g-');
            obj.MiniBDecision = plot(obj.TopAxes, NaN, NaN, 'b-');
            hold(obj.TopAxes, 'off');

            obj.CostBatchPlot = plot(obj.BotAxes, NaN, NaN, 'r-');
            hold(obj.BotAxes, 'on');
            obj.CostStochPlot = plot(obj.BotAxes, NaN, NaN, 'g-');
            obj.CostMiniBPlot = plot(obj.BotAxes, NaN, NaN, 'b-');
            hold(obj.BotAxes, 'off');
        end

        function update(obj)
            % Redraw the decision boundary of every enabled mode that
            % already has weights.
            if ~isempty(obj.Wbatch) && obj.showBatch
                obj.updateDecisionBoundary(obj.BatchDecision, obj.Wbatch, obj.W0batch);
            end

            if ~isempty(obj.Wstoch) && obj.showStochastic
                obj.updateDecisionBoundary(obj.StochDecision, obj.Wstoch, obj.W0stoch);
            end

            if ~isempty(obj.WminiB) && obj.showMiniBatch
                obj.updateDecisionBoundary(obj.MiniBDecision, obj.WminiB, obj.W0miniB);
            end
        end

        function initDecisionBound(obj)
            % Draw one random starting point and give it to every enabled
            % mode, so all modes start from the same weights.
            w = rand(2,1);
            w0 = rand(1);
            if obj.showBatch
                obj.Wbatch = w;
                obj.W0batch = w0;
            end

            if obj.showStochastic
                obj.Wstoch = w;
                obj.W0stoch = w0;
            end

            if obj.showMiniBatch
                obj.WminiB = w;
                obj.W0miniB = w0;
            end
        end

        function updateDecisionBoundary(obj, DecisionBoundary, W, W0)
            % Set the line object's data to W(1)*x + W(2)*y + W0 = 0.
            % Nothing is drawn until the data extent is known.
            if ~isempty(obj.Ymax)
                [X, Y] = Adaline.getline(W(1), W(2), W0, [obj.Xmin obj.Xmax obj.Ymin obj.Ymax]);
                DecisionBoundary.XData = X;
                DecisionBoundary.YData = Y;
            end
        end

        function ret = substitute(~, X, W, W0)
            % Linear activation X*W + W0 (X: n-by-2, W: 2-by-1).
            ret = X*W + W0;
        end

        function [Wret, W0ret] = graDescent(obj, X, y, W, W0, CostPlot)
            % One gradient-descent step.
            %   X, y     - tables holding the features and the targets
            %   W, W0    - current weights and bias
            %   CostPlot - line object that receives the cost of this step
            % Returns the updated weights and bias.
            features = X{:, :};
            targets = y{:, :};
            errors = targets - obj.substitute(features, W, W0);
            Wret = W + obj.alpha*transpose(features)*errors;
            W0ret= W0 + obj.alpha*sum(errors);
            % scale it by length of y so as to get a fair comparison
            cost = transpose(errors)*errors/2/length(targets);
            % The curve still holds its NaN placeholder (or nothing) before
            % the first step. all(isnan(...)) makes that test explicit
            % instead of relying on how "if" treats a logical vector.
            if all(isnan(CostPlot.XData(:)))
                CostPlot.YData = cost;
                CostPlot.XData = 1;
            else
                CostPlot.YData = [CostPlot.YData, cost];
                CostPlot.XData = [CostPlot.XData, CostPlot.XData(end) + 1];
            end
        end

    end
    methods(Static)
        % Utility functions
        function [X, Y] = getline(a, b, c, rng)
            % Two end points of the line a*x + b*y + c = 0.
            %   rng = [xmin xmax ymin ymax]
            % Returns 1x2 vectors X and Y:
            %   b ~= 0         : the segment spanning [xmin, xmax]
            %   b == 0, a ~= 0 : the vertical segment x = -c/a over [ymin, ymax]
            %   a == b == 0    : NaN (no line exists)
            XMIN=1;XMAX=2;YMIN=3;YMAX=4;
            if b ~= 0
                X = [rng(XMIN) rng(XMAX)];
                Y = (-c - a*X)/b;
            elseif a ~= 0
                % vertical line: y cannot be expressed as a function of x
                X = [-c/a, -c/a];
                Y = [rng(YMIN) rng(YMAX)];
            else
                X = [NaN NaN];
                Y = [NaN NaN];
            end
        end
    end

    methods(Access=public)
        function obj = Adaline(csvFile)
            % Build the chart from the training data stored in csvFile.

            obj.alpha = 0.005;
            obj.maxSteps = 200;
            obj.theshold = 0.00001;
            obj.batchSize = 20;

            obj.pauseSec = 0.2;

            % Preprocess data
            obj.preprocess(csvFile);

            % axis labels
            xlabel(obj.TopAxes, [obj.Xlabel '(scaled)']);
            ylabel(obj.TopAxes, [obj.Ylabel '(scaled)']);
            xlabel(obj.BotAxes, 'Epochs');
            ylabel(obj.BotAxes, 'Sum-squared-error');

            class1Data = obj.TrainingData(obj.TrainingData.cls==1, :);
            class2Data = obj.TrainingData(obj.TrainingData.cls==-1, :);

            obj.Class1Scatter.XData = class1Data.Var1s;
            obj.Class1Scatter.YData = class1Data.Var2s;

            obj.Class2Scatter.XData = class2Data.Var1s;
            obj.Class2Scatter.YData = class2Data.Var2s;

            obj.Xmin = min(obj.TrainingData.Var1s);
            obj.Xmax = max(obj.TrainingData.Var1s);
            obj.Ymin = min(obj.TrainingData.Var2s);
            obj.Ymax = max(obj.TrainingData.Var2s);
            xlim(obj.TopAxes, [obj.Xmin obj.Xmax]);
            ylim(obj.TopAxes, [obj.Ymin obj.Ymax]);

            % Default settings. Allow user to overwrite later
            obj.showBatch = true;
            obj.showStochastic = false;
            obj.showMiniBatch = false;
            obj.showHighlight = false;
            obj.outfile = []; % no animation output by default
        end

        function preprocess(obj, csvFile)
            % Read the table and append standardised (zero mean, unit
            % standard deviation) copies of its first two variables as
            % Var1s and Var2s.
            obj.TrainingData = readtable(csvFile);

            obj.Xlabel = obj.TrainingData.Properties.VariableNames{1};
            obj.Ylabel = obj.TrainingData.Properties.VariableNames{2};

            var1 = obj.TrainingData(:, obj.Xlabel).Variables;
            var1mean = mean(var1);
            var1std = std(var1);
            obj.TrainingData.Var1s = (var1 - var1mean)/var1std;

            var2 = obj.TrainingData(:, obj.Ylabel).Variables;
            var2mean = mean(var2);
            var2std = std(var2);
            obj.TrainingData.Var2s = (var2 - var2mean)/var2std;
        end
        function setLegend(obj)
            % Show a legend in the bottom axes listing only the enabled
            % modes; hide the legend when no mode is enabled.
            plots = gobjects(0);   % typed empty handle array
            legends = {};
            if obj.showBatch
                plots(end+1) = obj.CostBatchPlot;
                legends{end+1} = 'Batch';
            end
            if obj.showMiniBatch
                plots(end+1) = obj.CostMiniBPlot;
                legends{end+1} = 'Mini Batch';
            end
            if obj.showStochastic
                plots(end+1) = obj.CostStochPlot;
                legends{end+1} = 'Stochastic';
            end
            if isempty(plots)
                legend(obj.BotAxes, 'off');
                return;
            end
            legend(obj.BotAxes, plots, legends);
        end
        function animate(obj)
            % Run gradient descent for every enabled mode, one update per
            % mode per step, until maxSteps is reached or every enabled
            % mode has converged. Frames are appended to obj.outfile as an
            % animated GIF when it is set; otherwise the loop pauses
            % pauseSec between steps.

            % init decision boundary
            obj.initDecisionBound();
            xlim(obj.BotAxes, [0 obj.maxSteps]);

            % The weights restart from scratch, so the cost curves must too;
            % otherwise a second call keeps appending to the previous run.
            set([obj.CostBatchPlot, obj.CostStochPlot, obj.CostMiniBPlot], 'XData', NaN, 'YData', NaN);

            obj.setLegend();
            obj.DoneGraDes = [~obj.showBatch ~obj.showStochastic ~obj.showMiniBatch];

            numSamples = height(obj.TrainingData);
            % never ask for more distinct rows than the table has
            miniBatchSize = min(obj.batchSize, numSamples);
            firstFrame = true;
            for i=1:obj.maxSteps
                if obj.showBatch && ~obj.DoneGraDes(1)
                    [obj.Wbatch, obj.W0batch] = obj.graDescent(obj.TrainingData(:, {'Var1s', 'Var2s'}), obj.TrainingData(:, 'cls'), obj.Wbatch, obj.W0batch, obj.CostBatchPlot);

                    if length(obj.CostBatchPlot.YData) > 1
                        obj.DoneGraDes(1) = abs(obj.CostBatchPlot.YData(end) - obj.CostBatchPlot.YData(end-1)) < obj.theshold;
                    end
                end
                % minibatch gradient
                if obj.showMiniBatch && ~obj.DoneGraDes(3)
                    % randomly pick mini batch training data (distinct rows)
                    idx = randperm(numSamples, miniBatchSize);
                    batchTable = obj.TrainingData(idx, :);
                    [obj.WminiB, obj.W0miniB] = obj.graDescent(batchTable(:, {'Var1s', 'Var2s'}), batchTable(:, 'cls'), obj.WminiB, obj.W0miniB, obj.CostMiniBPlot);

                    % highlight the dots being used in gradient descent
                    if obj.showHighlight
                        obj.MiniBHighlight.XData = batchTable.Var1s;
                        obj.MiniBHighlight.YData = batchTable.Var2s;
                    end
                    if length(obj.CostMiniBPlot.YData) > 1
                        obj.DoneGraDes(3) = abs(obj.CostMiniBPlot.YData(end) - obj.CostMiniBPlot.YData(end-1)) < obj.theshold;
                    end
                end

                if obj.showStochastic && ~obj.DoneGraDes(2)
                    % randomly pick one row of the whole table for the stochastic step
                    j = randi(numSamples);
                    stochTable = obj.TrainingData(j, :);
                    [obj.Wstoch, obj.W0stoch] = obj.graDescent(stochTable(:, {'Var1s', 'Var2s'}), stochTable(:, 'cls'), obj.Wstoch, obj.W0stoch, obj.CostStochPlot);
                    if obj.showHighlight
                        obj.StochHighlight.XData = stochTable.Var1s;
                        obj.StochHighlight.YData = stochTable.Var2s;
                    end

                    if length(obj.CostStochPlot.YData) > 1
                        obj.DoneGraDes(2) = abs(obj.CostStochPlot.YData(end) - obj.CostStochPlot.YData(end-1)) < obj.theshold;
                    end
                end

                if all(obj.DoneGraDes)
                    break;
                end
                if ~isempty(obj.outfile)
                    % capture the figure that owns this chart, which is not
                    % necessarily the current figure
                    fig = ancestor(obj, 'figure');
                    if isempty(fig)
                        fig = gcf;
                    end
                    [img, map] = rgb2ind(frame2im( getframe(fig)),256);
                    if firstFrame
                        imwrite(img,map,obj.outfile,'gif','DelayTime',0.5);
                        firstFrame = false;
                    else
                        imwrite(img,map,obj.outfile,'gif','writemode', 'append','delaytime', obj.pauseSec);
                    end
                else
                    pause(obj.pauseSec);
                end
            end
        end
    end
end

0 commit comments

Comments
 (0)