Skip to content

Commit 6cb319a

Browse files
author
Alexander Ororbia
committed
refactored and passed test for deconv/stdp-deconv-syn and other minor cleanup for conv/deconv support
1 parent f72db76 commit 6cb319a

File tree

7 files changed

+120
-174
lines changed

7 files changed

+120
-174
lines changed

ngclearn/components/synapses/convolution/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .staticConvSynapse import StaticConvSynapse
33
from .deconvSynapse import DeconvSynapse
44
from .staticDeconvSynapse import StaticDeconvSynapse
5-
from .hebbianConvSynapse import HebbianConvSynapse
6-
from .hebbianDeconvSynapse import HebbianDeconvSynapse
5+
#from .hebbianConvSynapse import HebbianConvSynapse
6+
# from .hebbianDeconvSynapse import HebbianDeconvSynapse
77
from .traceSTDPConvSynapse import TraceSTDPConvSynapse
88
from .traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse

ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,15 @@ def __init__(
9999
dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
100100
)
101101

102-
# @transition(output_compartments=["outputs"])
103-
# @staticmethod
104102
@compilable
105103
def advance_state(self): #Rscale, padding, stride, weights, biases, inputs):
106104
_x = self.inputs.get()
107105
## FIXME: does resist_scale affect update rules?
108-
outputs = conv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * self.resist_scale + self.biases.get()
106+
outputs = conv2d(
107+
_x, self.weights.get(), stride_size=self.stride, padding=self.padding
108+
) * self.resist_scale + self.biases.get()
109109
self.outputs.set(outputs)
110110

111-
# @transition(output_compartments=["inputs", "outputs"])
112-
# @staticmethod
113111
@compilable
114112
def reset(self): #in_shape, out_shape):
115113
preVals = jnp.zeros(self.in_shape)

ngclearn/components/synapses/convolution/deconvSynapse.py

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngcsimlib.compilers.process import transition
4-
from ngcsimlib.component import Component
52
from ngcsimlib.compartment import Compartment
6-
3+
from ngcsimlib.parser import compilable
74
from ngclearn.utils.weight_distribution import initialize_params
85
from ngcsimlib.logger import info
9-
from ngclearn.utils import tensorstats
106
import ngclearn.utils.weight_distribution as dist
117
from ngclearn.components.synapses.convolution.ngcconv import deconv2d
128

9+
from ngclearn.components.jaxComponent import JaxComponent
10+
11+
1312
class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
1413
"""
1514
A base deconvolutional (transposed convolutional) synaptic cable.
@@ -61,7 +60,7 @@ def __init__(
6160
self.shape = shape ## shape of synaptic filter tensor
6261
x_size, x_size = x_shape
6362
self.x_size = x_size
64-
self.Rscale = resist_scale ## post-transformation scale factor
63+
self.resist_scale = resist_scale ## post-transformation scale factor
6564
self.padding = padding
6665
self.stride = stride
6766

@@ -70,7 +69,7 @@ def __init__(
7069
self.pad_args = None
7170

7271
######################### set up compartments ##########################
73-
tmp_key, *subkeys = random.split(self.key.value, 4)
72+
tmp_key, *subkeys = random.split(self.key.get(), 4)
7473
weights = dist.initialize_params(subkeys[0], filter_init,
7574
shape) ## filter tensor
7675
self.batch_size = batch_size # 1
@@ -89,36 +88,35 @@ def __init__(
8988
(1, shape[1]))
9089
if bias_init else 0.0)
9190

92-
@transition(output_compartments=["outputs"])
93-
@staticmethod
94-
def advance_state(Rscale, padding, stride, weights, biases, inputs):
95-
_x = inputs
96-
out = deconv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
97-
return out
98-
99-
@transition(output_compartments=["inputs", "outputs"])
100-
@staticmethod
101-
def reset(in_shape, out_shape):
102-
preVals = jnp.zeros(in_shape)
103-
postVals = jnp.zeros(out_shape)
104-
inputs = preVals
105-
outputs = postVals
106-
return inputs, outputs
107-
108-
def save(self, directory, **kwargs):
109-
file_name = directory + "/" + self.name + ".npz"
110-
if self.bias_init != None:
111-
jnp.savez(file_name, weights=self.weights.value,
112-
biases=self.biases.value)
113-
else:
114-
jnp.savez(file_name, weights=self.weights.value)
115-
116-
def load(self, directory, **kwargs):
117-
file_name = directory + "/" + self.name + ".npz"
118-
data = jnp.load(file_name)
119-
self.weights.set(data['weights'])
120-
if "biases" in data.keys():
121-
self.biases.set(data['biases'])
91+
@compilable
92+
def advance_state(self):
93+
_x = self.inputs.get()
94+
out = deconv2d(
95+
_x, self.weights.get(), stride_size=self.stride, padding=self.padding
96+
) * self.resist_scale + self.biases.get()
97+
self.outputs.set(out)
98+
99+
@compilable
100+
def reset(self): #in_shape, out_shape):
101+
preVals = jnp.zeros(self.in_shape)
102+
postVals = jnp.zeros(self.out_shape)
103+
self.inputs.set(preVals)
104+
self.outputs.set(postVals)
105+
106+
# def save(self, directory, **kwargs):
107+
# file_name = directory + "/" + self.name + ".npz"
108+
# if self.bias_init != None:
109+
# jnp.savez(file_name, weights=self.weights.get(),
110+
# biases=self.biases.get())
111+
# else:
112+
# jnp.savez(file_name, weights=self.weights.get())
113+
#
114+
# def load(self, directory, **kwargs):
115+
# file_name = directory + "/" + self.name + ".npz"
116+
# data = jnp.load(file_name)
117+
# self.weights.set(data['weights'])
118+
# if "biases" in data.keys():
119+
# self.biases.set(data['biases'])
122120

123121
@classmethod
124122
def help(cls): ## component help function
@@ -151,17 +149,3 @@ def help(cls): ## component help function
151149
"dynamics": "outputs = [K @.T inputs] * R + b",
152150
"hyperparameters": hyperparams}
153151
return info
154-
155-
def __repr__(self):
156-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
157-
maxlen = max(len(c) for c in comps) + 5
158-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
159-
for c in comps:
160-
stats = tensorstats(getattr(self, c).value)
161-
if stats is not None:
162-
line = [f"{k}: {v}" for k, v in stats.items()]
163-
line = ", ".join(line)
164-
else:
165-
line = "None"
166-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
167-
return lines

ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,9 @@ def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, de
146146
dWeights = (dW_ltp + dW_ltd)
147147
return dWeights
148148

149-
# @transition(output_compartments=["weights", "dWeights"])
150-
# @staticmethod
151149
@compilable
152150
def evolve(self):
153-
# pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
154-
# postSpike, postTrace, weights, eta
155-
156151
dWeights = self._compute_update()
157-
# dWeights = TraceSTDPConvSynapse._compute_update(
158-
# pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
159-
# )
160152
if self.w_decay > 0.: ## apply synaptic decay
161153
weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent
162154
else:
@@ -169,10 +161,8 @@ def evolve(self):
169161
self.weights.set(weights)
170162
self.dWeights.set(dWeights)
171163

172-
# @transition(output_compartments=["dInputs"])
173-
# @staticmethod
174164
@compilable
175-
def backtransmit(self): # x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
165+
def backtransmit(self):
176166
## action-backpropagating routine
177167
## calc dInputs - adjustment w.r.t. input signal
178168
k_size, k_size, n_in_chan, n_out_chan = self.shape

ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
42
from ngcsimlib.compartment import Compartment
5-
6-
from .deconvSynapse import DeconvSynapse
3+
from ngcsimlib.parser import compilable
74
from ngclearn.utils.weight_distribution import initialize_params
8-
from ngcsimlib.logger import info
9-
from ngclearn.utils import tensorstats
105
import ngclearn.utils.weight_distribution as dist
6+
7+
from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse
8+
119
from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
1210
_calc_dK_deconv, calc_dX_deconv,
1311
calc_dK_deconv)
14-
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
1512

1613
class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional cable
1714
"""
@@ -92,7 +89,7 @@ def __init__(
9289

9390
######################### set up compartments ##########################
9491
## Compartment setup and shape computation
95-
self.dWeights = Compartment(self.weights.value * 0)
92+
self.dWeights = Compartment(self.weights.get() * 0)
9693
self.dInputs = Compartment(jnp.zeros(self.in_shape))
9794
self.preSpike = Compartment(jnp.zeros(self.in_shape))
9895
self.preTrace = Compartment(jnp.zeros(self.in_shape))
@@ -108,76 +105,73 @@ def __init__(
108105
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
109106
k_size, k_size, n_in_chan, n_out_chan = shape
110107
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
111-
_d = deconv2d(_x, self.weights.value, stride_size=self.stride,
108+
_d = deconv2d(_x, self.weights.get(), stride_size=self.stride,
112109
padding=self.padding) * 0
113110
_dK = _calc_dK_deconv(_x, _d, stride_size=self.stride, out_size=k_size)
114111
## get filter update correction
115-
dx = _dK.shape[0] - self.weights.value.shape[0]
116-
dy = _dK.shape[1] - self.weights.value.shape[1]
112+
dx = _dK.shape[0] - self.weights.get().shape[0]
113+
dy = _dK.shape[1] - self.weights.get().shape[1]
117114
self.delta_shape = (abs(dx), abs(dy))
118115

119116
## get input update correction
120-
_dx = _calc_dX_deconv(self.weights.value, _d, stride_size=self.stride,
117+
_dx = _calc_dX_deconv(self.weights.get(), _d, stride_size=self.stride,
121118
padding=self.padding)
122119
dx = (_dx.shape[1] - _x.shape[1]) # abs()
123120
dy = (_dx.shape[2] - _x.shape[2])
124121
self.x_delta_shape = (dx, dy)
125122

126-
@staticmethod
127-
def _compute_update(
128-
pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, preSpike, preTrace, postSpike, postTrace
129-
):
130-
k_size, k_size, n_in_chan, n_out_chan = shape
123+
def _compute_update(self):
124+
k_size, k_size, n_in_chan, n_out_chan = self.shape
131125
## calc dFilters
132-
dW_ltp = calc_dK_deconv(preTrace - pretrace_target, postSpike * Aplus,
133-
delta_shape=delta_shape, stride_size=stride,
134-
out_size=k_size, padding=padding)
135-
dW_ltd = -calc_dK_deconv(preSpike, postTrace * Aminus,
136-
delta_shape=delta_shape, stride_size=stride,
137-
out_size=k_size, padding=padding)
126+
dW_ltp = calc_dK_deconv(
127+
self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus,
128+
delta_shape=self.delta_shape, stride_size=self.stride, out_size=k_size, padding=self.padding
129+
)
130+
dW_ltd = -calc_dK_deconv(
131+
self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape,
132+
stride_size=self.stride, out_size=k_size, padding=self.padding
133+
)
138134
dWeights = (dW_ltp + dW_ltd)
139135
return dWeights
140136

141-
@transition(output_compartments=["weights", "dWeights"])
142-
@staticmethod
143-
def evolve(
144-
pretrace_target, Aplus, Aminus, w_decay, w_bound, shape, stride, padding, delta_shape, preSpike, preTrace,
145-
postSpike, postTrace, weights, eta
146-
):
147-
dWeights = TraceSTDPDeconvSynapse._compute_update(
148-
pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
149-
preSpike, preTrace, postSpike, postTrace
150-
)
151-
if w_decay > 0.: ## apply synaptic decay
152-
weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
137+
@compilable
138+
def evolve(self):
139+
dWeights = self._compute_update()
140+
# dWeights = TraceSTDPDeconvSynapse._compute_update(
141+
# pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
142+
# preSpike, preTrace, postSpike, postTrace
143+
# )
144+
if self.w_decay > 0.: ## apply synaptic decay and conduct decayed STDP-ascent
145+
weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay
153146
else:
154-
weights = weights + dWeights * eta ## conduct STDP-ascent
147+
weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent
155148
## Apply any enforced filter constraints
156-
if w_bound > 0.: ## enforce non-negativity
149+
if self.w_bound > 0.: ## enforce non-negativity
157150
eps = 0.01 # 0.001
158-
weights = jnp.clip(weights, eps, w_bound - eps)
159-
return weights, dWeights
151+
weights = jnp.clip(weights, eps, self.w_bound - eps)
160152

161-
@transition(output_compartments=["dInputs"])
162-
@staticmethod
163-
def backtransmit(stride, padding, x_delta_shape, preSpike, postSpike, weights): ## action-backpropagating routine
153+
self.weights.set(weights)
154+
self.dWeights.set(dWeights)
155+
156+
@compilable
157+
def backtransmit(self): ## action-backpropagating routine
164158
## calc dInputs
165-
dInputs = calc_dX_deconv(weights, postSpike, delta_shape=x_delta_shape,
166-
stride_size=stride, padding=padding)
167-
return dInputs
168-
169-
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
170-
@staticmethod
171-
def reset(in_shape, out_shape):
172-
preVals = jnp.zeros(in_shape)
173-
postVals = jnp.zeros(out_shape)
174-
inputs = preVals
175-
outputs = postVals
176-
preSpike = preVals
177-
postSpike = postVals
178-
preTrace = preVals
179-
postTrace = postVals
180-
return inputs, outputs, preSpike, postSpike, preTrace, postTrace
159+
dInputs = calc_dX_deconv(
160+
self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
161+
padding=self.padding
162+
)
163+
self.dInputs.set(dInputs)
164+
165+
@compilable
166+
def reset(self): # in_shape, out_shape):
167+
preVals = jnp.zeros(self.in_shape.get())
168+
postVals = jnp.zeros(self.out_shape.get())
169+
self.inputs.set(preVals)
170+
self.outputs.set(postVals)
171+
self.preSpike.set(preVals)
172+
self.postSpike.set(postVals)
173+
self.preTrace.set(preVals)
174+
self.postTrace.set(postVals)
181175

182176
@classmethod
183177
def help(cls): ## component help function

tests/components/synapses/convolution/test_traceSTDPConvSynapse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def test_TraceSTDPConvSynapse1():
5454
[[1.], [0.]]]]
5555
)
5656

57-
reset_process.run() #
57+
reset_process.run() # ctx.reset()
5858
a.inputs.set(x)
5959
advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
6060
y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes
6161
assert_array_equal(y, y_truth)
62-
print(y)
63-
print("y.Tr:\n", y_truth)
64-
print("======")
62+
# print(y)
63+
# print("y.Tr:\n", y_truth)
64+
# print("======")
6565

6666
# print("NGC-Learn.shape = ", node.outputs.get().shape)
6767
a.preSpike.set(x)

0 commit comments

Comments
 (0)