1212 use nf_reshape_layer, only: reshape3d_layer
1313 use nf_linear2d_layer, only: linear2d_layer
1414 use nf_self_attention_layer, only: self_attention_layer
15+ use nf_layernorm_layer, only: layernorm_layer
1516 use nf_optimizers, only: optimizer_base_type
1617
1718contains
@@ -46,7 +47,7 @@ pure module subroutine backward_1d(self, previous, gradient)
4647
4748 type is (flatten_layer)
4849
49- ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
50+ ! Upstream layers permitted: input2d, input3d, conv2d, layernorm, maxpool2d
5051 select type (prev_layer = > previous % p)
5152 type is (input2d_layer)
5253 call this_layer % backward(prev_layer % output, gradient)
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
6061 call this_layer % backward(prev_layer % output, gradient)
6162 type is (self_attention_layer)
6263 call this_layer % backward(prev_layer % output, gradient)
64+ type is (layernorm_layer)
65+ call this_layer % backward(prev_layer % output, gradient)
6366 end select
6467
6568 end select
@@ -84,6 +87,8 @@ pure module subroutine backward_2d(self, previous, gradient)
8487 call this_layer % backward(prev_layer % output, gradient)
8588 type is (self_attention_layer)
8689 call this_layer % backward(prev_layer % output, gradient)
90+ type is (layernorm_layer)
91+ call this_layer % backward(prev_layer % output, gradient)
8792 end select
8893
8994 type is (self_attention_layer)
@@ -95,8 +100,18 @@ pure module subroutine backward_2d(self, previous, gradient)
95100 call this_layer % backward(prev_layer % output, gradient)
96101 type is (self_attention_layer)
97102 call this_layer % backward(prev_layer % output, gradient)
103+ type is (layernorm_layer)
104+ call this_layer % backward(prev_layer % output, gradient)
98105 end select
99106
107+ type is (layernorm_layer)
108+
109+ select type (prev_layer = > previous % p)
110+ type is (linear2d_layer)
111+ call this_layer % backward(prev_layer % output, gradient)
112+ type is (self_attention_layer)
113+ call this_layer % backward(prev_layer % output, gradient)
114+ end select
100115 end select
101116
102117 end subroutine backward_2d
@@ -234,6 +249,8 @@ module subroutine forward(self, input)
234249 call this_layer % forward(prev_layer % output)
235250 type is (linear2d_layer)
236251 call this_layer % forward(prev_layer % output)
252+ type is (layernorm_layer)
253+ call this_layer % forward(prev_layer % output)
237254 end select
238255
239256 type is (reshape3d_layer)
@@ -250,26 +267,40 @@ module subroutine forward(self, input)
250267
251268 type is (linear2d_layer)
252269
253- ! Upstream layers permitted: input2d, linear2d
270+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
254271 select type (prev_layer = > input % p)
255272 type is (input2d_layer)
256273 call this_layer % forward(prev_layer % output)
257274 type is (linear2d_layer)
258275 call this_layer % forward(prev_layer % output)
259276 type is (self_attention_layer)
260277 call this_layer % forward(prev_layer % output)
278+ type is (layernorm_layer)
279+ call this_layer % forward(prev_layer % output)
261280 end select
262281
263282 type is (self_attention_layer)
264283
265- ! Upstream layers permitted: input2d, linear2d
284+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
266285 select type (prev_layer = > input % p)
267286 type is (input2d_layer)
268287 call this_layer % forward(prev_layer % output)
269288 type is (linear2d_layer)
270289 call this_layer % forward(prev_layer % output)
271290 type is (self_attention_layer)
272291 call this_layer % forward(prev_layer % output)
292+ type is (layernorm_layer)
293+ call this_layer % forward(prev_layer % output)
294+ end select
295+
296+ type is (layernorm_layer)
297+
298+ ! Upstream layers permitted: linear2d, self_attention
299+ select type (prev_layer = > input % p)
300+ type is (linear2d_layer)
301+ call this_layer % forward(prev_layer % output)
302+ type is (self_attention_layer)
303+ call this_layer % forward(prev_layer % output)
273304 end select
274305
275306 end select
@@ -311,6 +342,8 @@ pure module subroutine get_output_2d(self, output)
311342 allocate (output, source= this_layer % output)
312343 type is (self_attention_layer)
313344 allocate (output, source= this_layer % output)
345+ type is (layernorm_layer)
346+ allocate (output, source= this_layer % output)
314347 class default
315348 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
316349
@@ -354,8 +387,8 @@ impure elemental module subroutine init(self, input)
354387 call this_layer % init(input % layer_shape)
355388 end select
356389
357- ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
358- ! self_attention layers is not known until we receive an input layer.
390+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
391+ ! self_attention or layernorm layers is not known until we receive an input layer.
359392 select type (this_layer = > self % p)
360393 type is (conv2d_layer)
361394 self % layer_shape = shape (this_layer % output)
@@ -367,6 +400,8 @@ impure elemental module subroutine init(self, input)
367400 self % layer_shape = shape (this_layer % output)
368401 type is (self_attention_layer)
369402 self % layer_shape = shape (this_layer % output)
403+ type is (layernorm_layer)
404+ self % layer_shape = shape (this_layer % output)
370405 type is (maxpool2d_layer)
371406 self % layer_shape = shape (this_layer % output)
372407 end select
@@ -425,6 +460,8 @@ elemental module function get_num_params(self) result(num_params)
425460 num_params = this_layer % get_num_params()
426461 type is (self_attention_layer)
427462 num_params = this_layer % get_num_params()
463+ type is (layernorm_layer)
464+ num_params = this_layer % get_num_params()
428465 class default
429466 error stop ' Unknown layer type.'
430467 end select
@@ -458,6 +495,8 @@ module function get_params(self) result(params)
458495 params = this_layer % get_params()
459496 type is (self_attention_layer)
460497 params = this_layer % get_params()
498+ type is (layernorm_layer)
499+ params = this_layer % get_params()
461500 class default
462501 error stop ' Unknown layer type.'
463502 end select
@@ -491,6 +530,8 @@ module function get_gradients(self) result(gradients)
491530 gradients = this_layer % get_gradients()
492531 type is (self_attention_layer)
493532 gradients = this_layer % get_gradients()
533+ type is (layernorm_layer)
534+ gradients = this_layer % get_gradients()
494535 class default
495536 error stop ' Unknown layer type.'
496537 end select
@@ -549,6 +590,9 @@ module subroutine set_params(self, params)
549590 type is (self_attention_layer)
550591 call this_layer % set_params(params)
551592
593+ type is (layernorm_layer)
594+ call this_layer % set_params(params)
595+
552596 type is (maxpool2d_layer)
553597 ! No parameters to set.
554598 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments