1+ classdef Adaline < matlab .graphics .chartcontainer .ChartContainer
2+ properties (Access = public )
3+ alpha ;
4+ maxSteps ;
5+ theshold ;
6+ batchSize ;
7+
8+ showBatch ;
9+ showMiniBatch ;
10+ showStochastic ;
11+
12+ pauseSec ;
13+ outfile ; % animation gif
14+ showHighlight ; % highlight dots used by minibatch/stochatic
15+ end
16+
17+ properties (Access = protected )
18+ TrainingData ;
19+
20+ Class1Scatter ;
21+ Class2Scatter ;
22+ MiniBHighlight ;
23+ StochHighlight ;
24+
25+ BatchDecision ;
26+ StochDecision ;
27+ MiniBDecision ;
28+
29+ CostBatchPlot ;
30+ CostStochPlot ;
31+ CostMiniBPlot ;
32+
33+ TopAxes ;
34+ BotAxes ;
35+
36+ % weights
37+ Wbatch ;
38+ W0batch ;
39+
40+ Wstoch ;
41+ W0stoch ;
42+
43+ WminiB ;
44+ W0miniB ;
45+ end
46+
47+ properties (Access = protected )
48+ Xmin ;
49+ Xmax ;
50+ Ymin ;
51+ Ymax ;
52+
53+ Xlabel ;
54+ Ylabel ;
55+ DoneGraDes ;
56+ end
57+
58+ methods (Access = protected )
59+ function setup(obj )
60+ tcl = getLayout(obj );
61+ tcl.GridSize = [2 1 ];
62+ obj.TopAxes = nexttile(tcl , 1 );
63+ obj.BotAxes = nexttile(tcl , 2 );
64+
65+ obj.Class1Scatter = scatter(obj .TopAxes , NaN , NaN , ' c*' );
66+ hold(obj .TopAxes , ' on' );
67+ obj.Class2Scatter = scatter(obj .TopAxes , NaN , NaN , ' mo' );
68+ obj.MiniBHighlight = scatter(obj .TopAxes , NaN , NaN , ' MarkerFaceColor' , ' b' , ' MarkerFaceAlpha' , 0.1 , ' MarkerEdgeColor' , ' b' );
69+ obj.StochHighlight = scatter(obj .TopAxes , NaN , NaN , ' MarkerFaceColor' , ' g' , ' MarkerFaceAlpha' , 0.1 , ' MarkerEdgeColor' , ' g' );
70+
71+ obj.BatchDecision = plot(obj .TopAxes , NaN , NaN , ' r-' );
72+ obj.StochDecision = plot(obj .TopAxes , NaN , NaN , ' g-' );
73+ obj.MiniBDecision = plot(obj .TopAxes , NaN , NaN , ' b-' );
74+ hold(obj .TopAxes , ' off' );
75+
76+ obj.CostBatchPlot = plot(obj .BotAxes , NaN , NaN , ' r-' );
77+ hold(obj .BotAxes , ' on' );
78+ obj.CostStochPlot = plot(obj .BotAxes , NaN , NaN , ' g-' );
79+ obj.CostMiniBPlot = plot(obj .BotAxes , NaN , NaN , ' b-' );
80+ hold(obj .BotAxes , ' off' );
81+ end
82+
83+ function update(obj )
84+ if ~isempty(obj .Wbatch ) && obj .showBatch
85+ obj .updateDecisionBoundary(obj .BatchDecision , obj .Wbatch , obj .W0batch );
86+ end
87+
88+ if ~isempty(obj .Wstoch ) && obj .showStochastic
89+ obj .updateDecisionBoundary(obj .StochDecision , obj .Wstoch , obj .W0stoch );
90+ end
91+
92+ if ~isempty(obj .WminiB ) && obj .showMiniBatch
93+ obj .updateDecisionBoundary(obj .MiniBDecision , obj .WminiB , obj .W0miniB );
94+ end
95+ end
96+
97+ function initDecisionBound(obj )
98+ w = rand(2 ,1 );
99+ w0 = rand(1 );
100+ if obj .showBatch
101+ obj.Wbatch = w ;
102+ obj.W0batch = w0 ;
103+ end
104+
105+ if obj .showStochastic
106+ obj.Wstoch = w ;
107+ obj.W0stoch = w0 ;
108+ end
109+
110+ if obj .showMiniBatch
111+ obj.WminiB = w ;
112+ obj.W0miniB = w0 ;
113+ end
114+ end
115+
116+ function updateDecisionBoundary(obj , DecisionBoundary , W , W0 )
117+ if ~isempty(obj .Ymax )
118+ [X , Y ] = Adaline .getline(W(1 ), W(2 ), W0 , [obj .Xmin obj .Xmax obj .Ymin obj .Ymax ]);
119+ DecisionBoundary.XData = X ;
120+ DecisionBoundary.YData = Y ;
121+ end
122+ end
123+
124+ function ret = substitute(~, X , W , W0 )
125+ ret = X * W + W0 ;
126+ end
127+
128+ function [Wret , W0ret ] = graDescent(obj , X , y , W , W0 , CostPlot )
129+ ret = obj .substitute(X{: , : }, W , W0 );
130+ errors = (y{: , : } - ret );
131+ Wret = W + obj .alpha * transpose(X{: , : })*errors ;
132+ W0ret= W0 + obj .alpha * sum(errors );
133+ % scale it by lenght of y so as to get a fair comparison
134+ cost = transpose(errors )*errors / 2 / length(y{: , : });
135+ if isnan(CostPlot .XData )
136+ CostPlot.YData = cost ;
137+ CostPlot.XData = 1 ;
138+ else
139+ CostPlot.YData = [CostPlot .YData , cost ];
140+ CostPlot.XData = [CostPlot .XData , CostPlot .XData(end ) + 1 ];
141+ end
142+ end
143+
144+ end
145+ methods (Static )
146+ % Utility functions
147+ function [X , Y ] = getline(a , b , c , rng )
148+ % a line of ax+by+c=0
149+ XMIN= 1 ;XMAX= 2 ;YMIN= 3 ;YMAX= 4 ;
150+ gety = @(x ) (-c - a * x )/b ;
151+
152+ xtmp = gety(rng(YMIN ));
153+ if xtmp < rng(XMIN )
154+ X(1 ) = xtmp ;
155+ Y(1 ) = gety(xtmp );
156+ else
157+ X(1 ) = rng(XMIN );
158+ Y(1 ) = gety(X(1 ));
159+ end
160+
161+ xtmp = gety(rng(YMAX ));
162+ if xtmp > rng(XMAX )
163+ X(2 ) = xtmp ;
164+ Y(2 ) = gety(xtmp );
165+ else
166+ X(2 ) = rng(XMAX );
167+ Y(2 ) = gety(X(2 ));
168+ end
169+ end
170+ end
171+
172+ methods (Access = public )
173+ function obj = Adaline(csvFile )
174+
175+ obj.alpha = 0.005 ;
176+ obj.maxSteps = 200 ;
177+ obj.theshold = 0.00001 ;
178+ obj.batchSize = 20 ;
179+
180+ obj.pauseSec = 0.2 ;
181+
182+ % Preprocess data
183+ obj .preprocess(csvFile );
184+
185+ % axis labels
186+ xlabel(obj .TopAxes , [obj .Xlabel ' (scaled)' ]);
187+ ylabel(obj .TopAxes , [obj .Ylabel ' (scaled)' ]);
188+ xlabel(obj .BotAxes , ' Epochs' );
189+ ylabel(obj .BotAxes , ' Sum-squared-error' );
190+
191+ class1Data = obj .TrainingData(obj .TrainingData .cls == 1 , : );
192+ class2Data = obj .TrainingData(obj .TrainingData .cls ==-1 , : );
193+
194+ obj.Class1Scatter.XData = class1Data .Var1s ;
195+ obj.Class1Scatter.YData = class1Data .Var2s ;
196+
197+ obj.Class2Scatter.XData = class2Data .Var1s ;
198+ obj.Class2Scatter.YData = class2Data .Var2s ;
199+
200+ obj.Xmin = min(obj .TrainingData .Var1s );
201+ obj.Xmax = max(obj .TrainingData .Var1s );
202+ obj.Ymin = min(obj .TrainingData .Var2s );
203+ obj.Ymax = max(obj .TrainingData .Var2s );
204+ xlim(obj .TopAxes , [obj .Xmin obj .Xmax ]);
205+ ylim(obj .TopAxes , [obj .Ymin obj .Ymax ]);
206+
207+ % Default settings. Allow user to overwrite later
208+ obj.showBatch = true ;
209+ obj.showStochastic = false ;
210+ obj.showMiniBatch = false ;
211+ obj.showHighlight = false ;
212+ obj.outfile = []; % no animation output by default
213+ end
214+
215+ function preprocess(obj , csvFile )
216+ obj.TrainingData = readtable(csvFile );
217+
218+ obj.Xlabel = obj.TrainingData.Properties.VariableNames{1 };
219+ obj.Ylabel = obj.TrainingData.Properties.VariableNames{2 };
220+
221+ var1 = obj .TrainingData(: , obj .Xlabel ).Variables;
222+ var1mean = mean(var1 );
223+ var1std = std(var1 );
224+ obj.TrainingData.Var1s = (var1 - var1mean )/var1std ;
225+
226+ var2 = obj .TrainingData(: , obj .Ylabel ).Variables;
227+ var2mean = mean(var2 );
228+ var2std = std(var2 );
229+ obj.TrainingData.Var2s = (var2 - var2mean )/var2std ;
230+ end
231+ function setLegend(obj )
232+ plots = [];
233+ legends = {};
234+ if obj .showBatch
235+ plots(end + 1 ) = obj .CostBatchPlot ;
236+ legends{end + 1 } = ' Batch' ;
237+ end
238+ if obj .showMiniBatch
239+ plots(end + 1 ) = obj .CostMiniBPlot ;
240+ legends{end + 1 } = ' Mini Batch' ;
241+ end
242+ if obj .showStochastic
243+ plots(end + 1 ) = obj .CostStochPlot ;
244+ legends{end + 1 } = ' Stochastic' ;
245+ end
246+ legend(obj .BotAxes , plots , legends );
247+ end
248+ function animate(obj )
249+ % init decision boundary
250+ obj .initDecisionBound();
251+ xlim(obj .BotAxes , [0 obj .maxSteps ]);
252+
253+ obj .setLegend();
254+ obj.DoneGraDes = [~obj .showBatch ~obj .showStochastic ~obj .showMiniBatch ];
255+
256+ numSamples = height(obj .TrainingData );
257+ firstFrame = true ;
258+ for i= 1 : obj .maxSteps
259+ if obj .showBatch && ~obj .DoneGraDes(1 )
260+ [obj .Wbatch , obj .W0batch ] = obj .graDescent(obj .TrainingData(: , {' Var1s' , ' Var2s' }), obj .TrainingData(: , ' cls' ), obj .Wbatch , obj .W0batch , obj .CostBatchPlot );
261+
262+ if length(obj .CostBatchPlot .YData ) > 1
263+ obj .DoneGraDes(1 ) = abs(obj .CostBatchPlot .YData(end ) - obj .CostBatchPlot .YData(end - 1 )) < obj .theshold ;
264+ end
265+ end
266+ % minibatch gradient
267+ if obj .showMiniBatch && ~obj .DoneGraDes(3 )
268+ % randomly pick mini batch training data
269+ idx = randsample([1 : numSamples ], obj .batchSize );
270+ batchTable = obj .TrainingData(idx , : );
271+ [obj .WminiB , obj .W0miniB ] = obj .graDescent(batchTable(: , {' Var1s' , ' Var2s' }), batchTable(: , ' cls' ), obj .WminiB , obj .W0miniB , obj .CostMiniBPlot );
272+
273+ % highlight the dots being used in gradient descent
274+ if obj .showHighlight
275+ obj.MiniBHighlight.XData = batchTable .Var1s ;
276+ obj.MiniBHighlight.YData = batchTable .Var2s ;
277+ end
278+ if length(obj .CostMiniBPlot .YData ) > 1
279+ obj .DoneGraDes(3 ) = abs(obj .CostMiniBPlot .YData(end ) - obj .CostMiniBPlot .YData(end - 1 )) < obj .theshold ;
280+ end
281+ end
282+
283+ if obj .showStochastic && ~obj .DoneGraDes(2 )
284+ % randomly pick one for stochastic batch
285+ j = randi(100 , 1 );
286+ stochTable = obj .TrainingData(j , : );
287+ [obj .Wstoch , obj .W0stoch ] = obj .graDescent(stochTable(: , {' Var1s' , ' Var2s' }), stochTable(: , ' cls' ), obj .Wstoch , obj .W0stoch , obj .CostStochPlot );
288+ if obj .showHighlight
289+ obj.StochHighlight.XData = stochTable .Var1s ;
290+ obj.StochHighlight.YData = stochTable .Var2s ;
291+ end
292+
293+ if length(obj .CostStochPlot .YData ) > 1
294+ obj .DoneGraDes(2 ) = abs(obj .CostStochPlot .YData(end ) - obj .CostStochPlot .YData(end - 1 )) < obj .theshold ;
295+ end
296+ end
297+
298+ if all(obj .DoneGraDes )
299+ break ;
300+ end
301+ if ~isempty(obj .outfile )
302+ [img , map ] = rgb2ind(frame2im( getframe(gcf )),256 );
303+ if firstFrame
304+ imwrite(img ,map ,obj .outfile ,' gif' ,' DelayTime' ,0.5 );
305+ firstFrame = false ;
306+ else
307+ imwrite(img ,map ,obj .outfile ,' gif' ,' writemode' , ' append' ,' delaytime' , obj .pauseSec );
308+ end
309+ else
310+ pause(obj .pauseSec );
311+ end
312+ end
313+ end
314+ end
315+ end
0 commit comments