@@ -283,7 +283,6 @@ end function get_activation_by_name
   pure module subroutine backward(self, output)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
-    real, allocatable :: gradient(:)
     integer :: n, num_layers

     num_layers = size(self % layers)
@@ -296,18 +295,25 @@ pure module subroutine backward(self, output)
         ! Output layer; apply the loss function
         select type (this_layer => self % layers(n) % p)
           type is (dense_layer)
-            gradient = quadratic_derivative(output, this_layer % output)
+            call self % layers(n) % backward( &
+              self % layers(n - 1), &
+              quadratic_derivative(output, this_layer % output) &
+            )
         end select
       else
         ! Hidden layer; take the gradient from the next layer
         select type (next_layer => self % layers(n + 1) % p)
           type is (dense_layer)
-            gradient = next_layer % gradient
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+          type is (flatten_layer)
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+          type is (conv2d_layer)
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+          type is (maxpool2d_layer)
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
         end select
       end if

-      call self % layers(n) % backward(self % layers(n - 1), gradient)
-
     end do

   end subroutine backward
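
Why the temporary went away: the `gradient` components of the candidate next layers evidently no longer share a rank (a dense or flatten layer carries a rank-1 gradient, while conv2d and maxpool2d layers carry higher-rank ones), so a single `real, allocatable :: gradient(:)` can no longer receive every case. Dispatching on the dynamic type with select type and passing `next_layer % gradient` straight into `backward` sidesteps the mismatch, at the cost of one near-identical branch per layer type. Below is a minimal, self-contained sketch of the pattern, using hypothetical stand-in types (`base_layer`, `dense_layer`, and a `conv2d_layer` with a rank-3 gradient) rather than the library's actual definitions:

! A standalone sketch (hypothetical types, not the library's code):
! gradients of different rank cannot share one rank-1 temporary, but
! select type resolves each dynamic type so the right-rank gradient
! can be passed along directly.
program select_type_gradients
  implicit none

  type :: base_layer
  end type base_layer

  type, extends(base_layer) :: dense_layer
    real, allocatable :: gradient(:)       ! rank-1 gradient
  end type dense_layer

  type, extends(base_layer) :: conv2d_layer
    real, allocatable :: gradient(:,:,:)   ! rank-3 gradient
  end type conv2d_layer

  class(base_layer), allocatable :: next_layer

  ! Pretend the next layer in the network is a conv2d layer.
  allocate(next_layer, source=conv2d_layer(reshape([1., 2., 3., 4.], [2, 2, 1])))

  ! A rank-1 temporary could hold the dense gradient but not the
  ! conv2d one; dispatching on the dynamic type avoids the temporary.
  select type (next_layer)
    type is (dense_layer)
      call backward_rank1(next_layer % gradient)
    type is (conv2d_layer)
      call backward_rank3(next_layer % gradient)
  end select

contains

  subroutine backward_rank1(gradient)
    real, intent(in) :: gradient(:)
    print *, 'received rank-1 gradient of size', size(gradient)
  end subroutine backward_rank1

  subroutine backward_rank3(gradient)
    real, intent(in) :: gradient(:,:,:)
    print *, 'received rank-3 gradient of shape', shape(gradient)
  end subroutine backward_rank3

end program select_type_gradients

Each `type is` guard resolves `next_layer % gradient` to a concrete rank, so each call can bind to a matching specific; presumably the library's `backward` is generic over the gradient's rank, which is why the four branches in the diff look identical yet cannot be collapsed into a single `class default` (inside `class default`, `next_layer` has only the base type and no `gradient` component).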