11from jax import random , numpy as jnp , jit
2- from ngcsimlib .compilers .process import transition
3- from ngcsimlib .component import Component
42from ngcsimlib .compartment import Compartment
5-
6- from .deconvSynapse import DeconvSynapse
3+ from ngcsimlib .parser import compilable
74from ngclearn .utils .weight_distribution import initialize_params
8- from ngcsimlib .logger import info
9- from ngclearn .utils import tensorstats
105import ngclearn .utils .weight_distribution as dist
6+
7+ from ngclearn .components .synapses .convolution .deconvSynapse import DeconvSynapse
8+
119from ngclearn .components .synapses .convolution .ngcconv import (deconv2d , _calc_dX_deconv ,
1210 _calc_dK_deconv , calc_dX_deconv ,
1311 calc_dK_deconv )
14- from ngclearn .utils .optim import get_opt_init_fn , get_opt_step_fn
1512
1613class TraceSTDPDeconvSynapse (DeconvSynapse ): ## trace-based STDP deconvolutional cable
1714 """
@@ -92,7 +89,7 @@ def __init__(
9289
9390 ######################### set up compartments ##########################
9491 ## Compartment setup and shape computation
95- self .dWeights = Compartment (self .weights .value * 0 )
92+ self .dWeights = Compartment (self .weights .get () * 0 )
9693 self .dInputs = Compartment (jnp .zeros (self .in_shape ))
9794 self .preSpike = Compartment (jnp .zeros (self .in_shape ))
9895 self .preTrace = Compartment (jnp .zeros (self .in_shape ))
@@ -108,76 +105,73 @@ def __init__(
108105 def _init (self , batch_size , x_size , shape , stride , padding , pad_args , weights ):
109106 k_size , k_size , n_in_chan , n_out_chan = shape
110107 _x = jnp .zeros ((batch_size , x_size , x_size , n_in_chan ))
111- _d = deconv2d (_x , self .weights .value , stride_size = self .stride ,
108+ _d = deconv2d (_x , self .weights .get () , stride_size = self .stride ,
112109 padding = self .padding ) * 0
113110 _dK = _calc_dK_deconv (_x , _d , stride_size = self .stride , out_size = k_size )
114111 ## get filter update correction
115- dx = _dK .shape [0 ] - self .weights .value .shape [0 ]
116- dy = _dK .shape [1 ] - self .weights .value .shape [1 ]
112+ dx = _dK .shape [0 ] - self .weights .get () .shape [0 ]
113+ dy = _dK .shape [1 ] - self .weights .get () .shape [1 ]
117114 self .delta_shape = (abs (dx ), abs (dy ))
118115
119116 ## get input update correction
120- _dx = _calc_dX_deconv (self .weights .value , _d , stride_size = self .stride ,
117+ _dx = _calc_dX_deconv (self .weights .get () , _d , stride_size = self .stride ,
121118 padding = self .padding )
122119 dx = (_dx .shape [1 ] - _x .shape [1 ]) # abs()
123120 dy = (_dx .shape [2 ] - _x .shape [2 ])
124121 self .x_delta_shape = (dx , dy )
125122
126- @staticmethod
127- def _compute_update (
128- pretrace_target , Aplus , Aminus , shape , stride , padding , delta_shape , preSpike , preTrace , postSpike , postTrace
129- ):
130- k_size , k_size , n_in_chan , n_out_chan = shape
123+ def _compute_update (self ):
124+ k_size , k_size , n_in_chan , n_out_chan = self .shape
131125 ## calc dFilters
132- dW_ltp = calc_dK_deconv (preTrace - pretrace_target , postSpike * Aplus ,
133- delta_shape = delta_shape , stride_size = stride ,
134- out_size = k_size , padding = padding )
135- dW_ltd = - calc_dK_deconv (preSpike , postTrace * Aminus ,
136- delta_shape = delta_shape , stride_size = stride ,
137- out_size = k_size , padding = padding )
126+ dW_ltp = calc_dK_deconv (
127+ self .preTrace .get () - self .pretrace_target , self .postSpike .get () * self .Aplus ,
128+ delta_shape = self .delta_shape , stride_size = self .stride , out_size = k_size , padding = self .padding
129+ )
130+ dW_ltd = - calc_dK_deconv (
131+ self .preSpike .get (), self .postTrace .get () * self .Aminus , delta_shape = self .delta_shape ,
132+ stride_size = self .stride , out_size = k_size , padding = self .padding
133+ )
138134 dWeights = (dW_ltp + dW_ltd )
139135 return dWeights
140136
141- @transition (output_compartments = ["weights" , "dWeights" ])
142- @staticmethod
143- def evolve (
144- pretrace_target , Aplus , Aminus , w_decay , w_bound , shape , stride , padding , delta_shape , preSpike , preTrace ,
145- postSpike , postTrace , weights , eta
146- ):
147- dWeights = TraceSTDPDeconvSynapse ._compute_update (
148- pretrace_target , Aplus , Aminus , shape , stride , padding , delta_shape ,
149- preSpike , preTrace , postSpike , postTrace
150- )
151- if w_decay > 0. : ## apply synaptic decay
152- weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
137+ @compilable
138+ def evolve (self ):
139+ dWeights = self ._compute_update ()
140+ # dWeights = TraceSTDPDeconvSynapse._compute_update(
141+ # pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
142+ # preSpike, preTrace, postSpike, postTrace
143+ # )
144+ if self .w_decay > 0. : ## apply synaptic decay and conduct decayed STDP-ascent
145+ weights = self .weights .get () + dWeights * self .eta - self .weights .get () * self .w_decay
153146 else :
154- weights = weights + dWeights * eta ## conduct STDP-ascent
147+ weights = self . weights . get () + dWeights * self . eta ## conduct STDP-ascent
155148 ## Apply any enforced filter constraints
156- if w_bound > 0. : ## enforce non-negativity
149+ if self . w_bound > 0. : ## enforce non-negativity
157150 eps = 0.01 # 0.001
158- weights = jnp .clip (weights , eps , w_bound - eps )
159- return weights , dWeights
151+ weights = jnp .clip (weights , eps , self .w_bound - eps )
160152
161- @transition (output_compartments = ["dInputs" ])
162- @staticmethod
163- def backtransmit (stride , padding , x_delta_shape , preSpike , postSpike , weights ): ## action-backpropagating routine
153+ self .weights .set (weights )
154+ self .dWeights .set (dWeights )
155+
156+ @compilable
157+ def backtransmit (self ): ## action-backpropagating routine
164158 ## calc dInputs
165- dInputs = calc_dX_deconv (weights , postSpike , delta_shape = x_delta_shape ,
166- stride_size = stride , padding = padding )
167- return dInputs
168-
169- @ transition ( output_compartments = [ "inputs" , "outputs" , "preSpike" , "postSpike" , "preTrace" , "postTrace" ] )
170- @ staticmethod
171- def reset ( in_shape , out_shape ):
172- preVals = jnp . zeros ( in_shape )
173- postVals = jnp .zeros (out_shape )
174- inputs = preVals
175- outputs = postVals
176- preSpike = preVals
177- postSpike = postVals
178- preTrace = preVals
179- postTrace = postVals
180- return inputs , outputs , preSpike , postSpike , preTrace , postTrace
159+ dInputs = calc_dX_deconv (
160+ self . weights . get (), self . postSpike . get (), delta_shape = self . x_delta_shape , stride_size = self . stride ,
161+ padding = self . padding
162+ )
163+ self . dInputs . set ( dInputs )
164+
165+ @ compilable
166+ def reset ( self ): # in_shape, out_shape):
167+ preVals = jnp .zeros (self . in_shape . get () )
168+ postVals = jnp . zeros ( self . out_shape . get ())
169+ self . inputs . set ( preVals )
170+ self . outputs . set ( postVals )
171+ self . preSpike . set ( preVals )
172+ self . postSpike . set ( postVals )
173+ self . preTrace . set ( preVals )
174+ self . postTrace . set ( postVals )
181175
182176 @classmethod
183177 def help (cls ): ## component help function
0 commit comments