 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
 from ngcsimlib.compartment import Compartment
-
-from .convSynapse import ConvSynapse
+from ngcsimlib.parser import compilable
 from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
 import ngclearn.utils.weight_distribution as dist
+
+from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
+
 from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
                                                               _conv_valid_transpose_padding)
 from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -93,7 +91,7 @@ def __init__(
 
         ######################### set up compartments ##########################
         ## Compartment setup and shape computation
-        self.dWeights = Compartment(self.weights.value * 0)
+        self.dWeights = Compartment(self.weights.get() * 0)
         self.dInputs = Compartment(jnp.zeros(self.in_shape))
         self.preSpike = Compartment(jnp.zeros(self.in_shape))
         self.preTrace = Compartment(jnp.zeros(self.in_shape))
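The .value -> .get() change that runs through this diff reflects an accessor-style compartment interface. As a rough illustration only (a hypothetical stand-in, not the actual ngcsimlib Compartment class), the read/write pattern used by the rewritten methods looks like this:

import jax.numpy as jnp

class CompartmentSketch:
    """Hypothetical stand-in mirroring the get()/set() usage in this diff."""
    def __init__(self, init_value):
        self._value = init_value    # stored state, e.g. a jnp array

    def get(self):
        return self._value          # read access, replacing the old .value attribute

    def set(self, new_value):
        self._value = new_value     # write access, replacing direct assignment

dW = CompartmentSketch(jnp.zeros((3, 3, 1, 16)))
dW.set(dW.get() * 0)                # read-modify-write pattern used throughout the rewritten methods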
@@ -108,94 +106,99 @@ def __init__(
         k_size, k_size, n_in_chan, n_out_chan = self.shape
         if padding == "SAME":
             self.antiPad = _conv_same_transpose_padding(
-                self.postSpike.value.shape[1],
+                self.postSpike.get().shape[1],
                 self.x_size, k_size, stride)
         elif padding == "VALID":
             self.antiPad = _conv_valid_transpose_padding(
-                self.postSpike.value.shape[1],
+                self.postSpike.get().shape[1],
                 self.x_size, k_size, stride)
         ########################################################################
 
     def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
-        _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
+        _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
         _dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
         ## get filter update correction
-        dx = _dK.shape[0] - weights.value.shape[0]
-        dy = _dK.shape[1] - weights.value.shape[1]
+        dx = _dK.shape[0] - weights.get().shape[0]
+        dy = _dK.shape[1] - weights.get().shape[1]
         #self.delta_shape = (dx, dy)
         self.delta_shape = (max(dx, 0), max(dy, 0))
         ## get input update correction
-        _dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
+        _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride,
                             anti_padding=pad_args)
         dx = (_dx.shape[1] - _x.shape[1])
         dy = (_dx.shape[2] - _x.shape[2])
         self.x_delta_shape = (dx, dy)
 
-    @staticmethod
-    def _compute_update(
-            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
-    ):
+    #@staticmethod
+    def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
         ## Compute long-term potentiation to filters
         dW_ltp = calc_dK_conv(
-            preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+            self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, delta_shape=self.delta_shape,
+            stride_size=self.stride, padding=self.pad_args
         )
         ## Compute long-term depression to filters
         dW_ltd = -calc_dK_conv(
-            preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+            self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape,
+            stride_size=self.stride, padding=self.pad_args
         )
         dWeights = (dW_ltp + dW_ltd)
         return dWeights
 
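For orientation, _compute_update assembles the convolutional form of the standard trace-based STDP rule: potentiation weighs (pre-synaptic trace minus a target) against post-synaptic spikes, depression weighs pre-synaptic spikes against the post-synaptic trace. A dense (fully connected) sketch of the same rule, with hypothetical array names and shapes, would be:

import jax.numpy as jnp

def dense_trace_stdp_dW(s_pre, z_pre, s_post, z_post, Aplus, Aminus, x_tar=0.0):
    # s_pre, z_pre: (batch, n_in) pre-synaptic spikes and traces
    # s_post, z_post: (batch, n_out) post-synaptic spikes and traces
    dW_ltp = jnp.matmul((z_pre - x_tar).T, s_post * Aplus)  # long-term potentiation
    dW_ltd = -jnp.matmul(s_pre.T, z_post * Aminus)          # long-term depression
    return dW_ltp + dW_ltd                                  # (n_in, n_out) synaptic update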
-    @transition(output_compartments=["weights", "dWeights"])
-    @staticmethod
-    def evolve(
-            pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
-            postSpike, postTrace, weights, eta
-    ):
-        dWeights = TraceSTDPConvSynapse._compute_update(
-            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
-        )
-        if w_decay > 0.: ## apply synaptic decay
-            weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
+    # @transition(output_compartments=["weights", "dWeights"])
+    # @staticmethod
+    @compilable
+    def evolve(self):
+        # pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
+        # postSpike, postTrace, weights, eta
+
+        dWeights = self._compute_update()
+        # dWeights = TraceSTDPConvSynapse._compute_update(
+        #     pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
+        # )
+        if self.w_decay > 0.: ## apply synaptic decay
+            weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent
         else:
-            weights = weights + dWeights * eta ## conduct STDP-ascent
+            weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent
         ## Apply any enforced filter constraints
-        if w_bound > 0.: ## enforce non-negativity
+        if self.w_bound > 0.: ## enforce non-negativity
             eps = 0.01 # 0.001
-            weights = jnp.clip(weights, eps, w_bound - eps)
-        return weights, dWeights
-
-    @transition(output_compartments=["dInputs"])
-    @staticmethod
-    def backtransmit(
-            x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
-    ): ## action-backpropagating routine
+            weights = jnp.clip(weights, eps, self.w_bound - eps)
+
+        self.weights.set(weights)
+        self.dWeights.set(dWeights)
+
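Continuing the dense sketch above, the body of evolve amounts to an (optionally decayed) STDP-ascent step followed by clipping the filters into the interval (eps, w_bound - eps); again hypothetical names, not the library call:

import jax.numpy as jnp

def dense_trace_stdp_step(W, dW, eta, w_decay=0.0, w_bound=1.0, eps=0.01):
    W = W + dW * eta - (W * w_decay if w_decay > 0. else 0.)  # (decayed) STDP-ascent
    if w_bound > 0.:
        W = jnp.clip(W, eps, w_bound - eps)  # keep filters bounded and non-negative
    return W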
+    # @transition(output_compartments=["dInputs"])
+    # @staticmethod
+    @compilable
+    def backtransmit(self): # x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
+        ## action-backpropagating routine
         ## calc dInputs - adjustment w.r.t. input signal
-        k_size, k_size, n_in_chan, n_out_chan = shape
+        k_size, k_size, n_in_chan, n_out_chan = self.shape
         # antiPad = None
         # if padding == "SAME":
         #     antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
         #                                            k_size, stride)
         # elif padding == "VALID":
         #     antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
         #                                             k_size, stride)
-        dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
-        return dInputs
-
-    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
-    @staticmethod
-    def reset(in_shape, out_shape):
-        preVals = jnp.zeros(in_shape)
-        postVals = jnp.zeros(out_shape)
-        inputs = preVals
-        outputs = postVals
-        preSpike = preVals
-        postSpike = postVals
-        preTrace = preVals
-        postTrace = postVals
-        return inputs, outputs, preSpike, postSpike, preTrace, postTrace
+        dInputs = calc_dX_conv(
+            self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
+            anti_padding=self.antiPad
+        )
+        self.dInputs.set(dInputs)
+
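As a rough intuition for backtransmit: dInputs carries the post-synaptic spikes back to input space through the filters, which behaves like a transposed convolution up to the padding/shape corrections handled by antiPad and x_delta_shape. An illustrative approximation only (assumed NHWC activations and HWIO kernels; not the ngclearn helper itself):

from jax import lax

def approx_dinputs(postSpike, weights, stride, padding="SAME"):
    # postSpike: (batch, y, y, n_out_chan); weights: (k, k, n_in_chan, n_out_chan)
    return lax.conv_transpose(
        postSpike, weights, strides=(stride, stride), padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        transpose_kernel=True  # use the adjoint of the forward convolution
    )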
+    @compilable
+    def reset(self): # in_shape, out_shape):
+        preVals = jnp.zeros(self.in_shape)
+        postVals = jnp.zeros(self.out_shape)
+        self.inputs.set(preVals)
+        self.outputs.set(postVals)
+        self.preSpike.set(preVals)
+        self.postSpike.set(postVals)
+        self.preTrace.set(preVals)
+        self.postTrace.set(postVals)
 
     @classmethod
     def help(cls): ## component help function