Skip to content

Commit f72db76

Browse files
author
Alexander Ororbia
committed
refactored stdp-conv-syn/conv-syn and test passed
1 parent 94477b8 commit f72db76

File tree

3 files changed

+120
-148
lines changed

3 files changed

+120
-148
lines changed

ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 38 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngcsimlib.compilers.process import transition
4-
from ngcsimlib.component import Component
52
from ngcsimlib.compartment import Compartment
6-
3+
from ngcsimlib.parser import compilable
74
from ngclearn.utils.weight_distribution import initialize_params
85
from ngcsimlib.logger import info
9-
from ngclearn.utils import tensorstats
106
import ngclearn.utils.weight_distribution as dist
117
from ngclearn.components.synapses.convolution.ngcconv import conv2d
128

9+
from ngclearn.components.jaxComponent import JaxComponent
10+
1311
class ConvSynapse(JaxComponent): ## base-level convolutional cable
1412
"""
1513
A base convolutional synaptic cable.
@@ -61,15 +59,15 @@ def __init__(
6159
self.shape = shape ## shape of synaptic filter tensor
6260
x_size, x_size = x_shape
6361
self.x_size = x_size
64-
self.Rscale = resist_scale ## post-transformation scale factor
62+
self.resist_scale = resist_scale ## post-transformation scale factor
6563
self.padding = padding
6664
self.stride = stride
6765

6866
####################### Set up padding arguments #######################
6967
k_size, k_size, n_in_chan, n_out_chan = shape
7068
self.pad_args = None
7169
if self.padding is not None and self.padding == "SAME":
72-
if (x_size % stride == 0):
70+
if x_size % stride == 0:
7371
pad_along_height = max(k_size - stride, 0)
7472
else:
7573
pad_along_height = max(k_size - (x_size % stride), 0)
@@ -83,7 +81,7 @@ def __init__(
8381
self.pad_args = ((0, 0), (0, 0))
8482

8583
######################### set up compartments ##########################
86-
tmp_key, *subkeys = random.split(self.key.value, 4)
84+
tmp_key, *subkeys = random.split(self.key.get(), 4)
8785
weights = dist.initialize_params(subkeys[0], filter_init, shape) ## filter tensor
8886
self.batch_size = batch_size # 1
8987
## Compartment setup and shape computation
@@ -101,36 +99,38 @@ def __init__(
10199
dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
102100
)
103101

104-
@transition(output_compartments=["outputs"])
105-
@staticmethod
106-
def advance_state(Rscale, padding, stride, weights, biases, inputs):
107-
_x = inputs
108-
outputs = conv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
109-
return outputs
110-
111-
@transition(output_compartments=["inputs", "outputs"])
112-
@staticmethod
113-
def reset(in_shape, out_shape):
114-
preVals = jnp.zeros(in_shape)
115-
postVals = jnp.zeros(out_shape)
116-
inputs = preVals
117-
outputs = postVals
118-
return inputs, outputs
119-
120-
def save(self, directory, **kwargs):
121-
file_name = directory + "/" + self.name + ".npz"
122-
if self.bias_init != None:
123-
jnp.savez(file_name, weights=self.weights.value,
124-
biases=self.biases.value)
125-
else:
126-
jnp.savez(file_name, weights=self.weights.value)
127-
128-
def load(self, directory, **kwargs):
129-
file_name = directory + "/" + self.name + ".npz"
130-
data = jnp.load(file_name)
131-
self.weights.set(data['weights'])
132-
if "biases" in data.keys():
133-
self.biases.set(data['biases'])
102+
# @transition(output_compartments=["outputs"])
103+
# @staticmethod
104+
@compilable
105+
def advance_state(self): #Rscale, padding, stride, weights, biases, inputs):
106+
_x = self.inputs.get()
107+
## FIXME: does resist_scale affect update rules?
108+
outputs = conv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * self.resist_scale + self.biases.get()
109+
self.outputs.set(outputs)
110+
111+
# @transition(output_compartments=["inputs", "outputs"])
112+
# @staticmethod
113+
@compilable
114+
def reset(self): #in_shape, out_shape):
115+
preVals = jnp.zeros(self.in_shape)
116+
postVals = jnp.zeros(self.out_shape)
117+
self.inputs.set(preVals)
118+
self.outputs.set(postVals)
119+
120+
# def save(self, directory, **kwargs):
121+
# file_name = directory + "/" + self.name + ".npz"
122+
# if self.bias_init != None:
123+
# jnp.savez(file_name, weights=self.weights.get(),
124+
# biases=self.biases.get())
125+
# else:
126+
# jnp.savez(file_name, weights=self.weights.get())
127+
#
128+
# def load(self, directory, **kwargs):
129+
# file_name = directory + "/" + self.name + ".npz"
130+
# data = jnp.load(file_name)
131+
# self.weights.set(data['weights'])
132+
# if "biases" in data.keys():
133+
# self.biases.set(data['biases'])
134134

135135
@classmethod
136136
def help(cls): ## component help function
@@ -163,17 +163,3 @@ def help(cls): ## component help function
163163
"dynamics": "outputs = [K @ inputs] * R + b",
164164
"hyperparameters": hyperparams}
165165
return info
166-
167-
def __repr__(self):
168-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
169-
maxlen = max(len(c) for c in comps) + 5
170-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
171-
for c in comps:
172-
stats = tensorstats(getattr(self, c).value)
173-
if stats is not None:
174-
line = [f"{k}: {v}" for k, v in stats.items()]
175-
line = ", ".join(line)
176-
else:
177-
line = "None"
178-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
179-
return lines

ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
42
from ngcsimlib.compartment import Compartment
5-
6-
from .convSynapse import ConvSynapse
3+
from ngcsimlib.parser import compilable
74
from ngclearn.utils.weight_distribution import initialize_params
8-
from ngcsimlib.logger import info
9-
from ngclearn.utils import tensorstats
105
import ngclearn.utils.weight_distribution as dist
6+
7+
from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
8+
119
from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
1210
_conv_valid_transpose_padding)
1311
from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -93,7 +91,7 @@ def __init__(
9391

9492
######################### set up compartments ##########################
9593
## Compartment setup and shape computation
96-
self.dWeights = Compartment(self.weights.value * 0)
94+
self.dWeights = Compartment(self.weights.get() * 0)
9795
self.dInputs = Compartment(jnp.zeros(self.in_shape))
9896
self.preSpike = Compartment(jnp.zeros(self.in_shape))
9997
self.preTrace = Compartment(jnp.zeros(self.in_shape))
@@ -108,94 +106,99 @@ def __init__(
108106
k_size, k_size, n_in_chan, n_out_chan = self.shape
109107
if padding == "SAME":
110108
self.antiPad = _conv_same_transpose_padding(
111-
self.postSpike.value.shape[1],
109+
self.postSpike.get().shape[1],
112110
self.x_size, k_size, stride)
113111
elif padding == "VALID":
114112
self.antiPad = _conv_valid_transpose_padding(
115-
self.postSpike.value.shape[1],
113+
self.postSpike.get().shape[1],
116114
self.x_size, k_size, stride)
117115
########################################################################
118116

119117
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
120118
k_size, k_size, n_in_chan, n_out_chan = shape
121119
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
122-
_d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
120+
_d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
123121
_dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
124122
## get filter update correction
125-
dx = _dK.shape[0] - weights.value.shape[0]
126-
dy = _dK.shape[1] - weights.value.shape[1]
123+
dx = _dK.shape[0] - weights.get().shape[0]
124+
dy = _dK.shape[1] - weights.get().shape[1]
127125
#self.delta_shape = (dx, dy)
128126
self.delta_shape = (max(dx, 0), max(dy, 0))
129127
## get input update correction
130-
_dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
128+
_dx = _calc_dX_conv(weights.get(), _d, stride_size=stride,
131129
anti_padding=pad_args)
132130
dx = (_dx.shape[1] - _x.shape[1])
133131
dy = (_dx.shape[2] - _x.shape[2])
134132
self.x_delta_shape = (dx, dy)
135133

136-
@staticmethod
137-
def _compute_update(
138-
pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
139-
):
134+
#@staticmethod
135+
def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
140136
## Compute long-term potentiation to filters
141137
dW_ltp = calc_dK_conv(
142-
preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
138+
self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, delta_shape=self.delta_shape,
139+
stride_size=self.stride, padding=self.pad_args
143140
)
144141
## Compute long-term depression to filters
145142
dW_ltd = -calc_dK_conv(
146-
preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
143+
self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape,
144+
stride_size=self.stride, padding=self.pad_args
147145
)
148146
dWeights = (dW_ltp + dW_ltd)
149147
return dWeights
150148

151-
@transition(output_compartments=["weights", "dWeights"])
152-
@staticmethod
153-
def evolve(
154-
pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
155-
postSpike, postTrace, weights, eta
156-
):
157-
dWeights = TraceSTDPConvSynapse._compute_update(
158-
pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
159-
)
160-
if w_decay > 0.: ## apply synaptic decay
161-
weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
149+
# @transition(output_compartments=["weights", "dWeights"])
150+
# @staticmethod
151+
@compilable
152+
def evolve(self):
153+
# pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
154+
# postSpike, postTrace, weights, eta
155+
156+
dWeights = self._compute_update()
157+
# dWeights = TraceSTDPConvSynapse._compute_update(
158+
# pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
159+
# )
160+
if self.w_decay > 0.: ## apply synaptic decay
161+
weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent
162162
else:
163-
weights = weights + dWeights * eta ## conduct STDP-ascent
163+
weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent
164164
## Apply any enforced filter constraints
165-
if w_bound > 0.: ## enforce non-negativity
165+
if self.w_bound > 0.: ## enforce non-negativity
166166
eps = 0.01 # 0.001
167-
weights = jnp.clip(weights, eps, w_bound - eps)
168-
return weights, dWeights
169-
170-
@transition(output_compartments=["dInputs"])
171-
@staticmethod
172-
def backtransmit(
173-
x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
174-
): ## action-backpropagating routine
167+
weights = jnp.clip(weights, eps, self.w_bound - eps)
168+
169+
self.weights.set(weights)
170+
self.dWeights.set(dWeights)
171+
172+
# @transition(output_compartments=["dInputs"])
173+
# @staticmethod
174+
@compilable
175+
def backtransmit(self): # x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
176+
## action-backpropagating routine
175177
## calc dInputs - adjustment w.r.t. input signal
176-
k_size, k_size, n_in_chan, n_out_chan = shape
178+
k_size, k_size, n_in_chan, n_out_chan = self.shape
177179
# antiPad = None
178180
# if padding == "SAME":
179181
# antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
180182
# k_size, stride)
181183
# elif padding == "VALID":
182184
# antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
183185
# k_size, stride)
184-
dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
185-
return dInputs
186-
187-
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
188-
@staticmethod
189-
def reset(in_shape, out_shape):
190-
preVals = jnp.zeros(in_shape)
191-
postVals = jnp.zeros(out_shape)
192-
inputs = preVals
193-
outputs = postVals
194-
preSpike = preVals
195-
postSpike = postVals
196-
preTrace = preVals
197-
postTrace = postVals
198-
return inputs, outputs, preSpike, postSpike, preTrace, postTrace
186+
dInputs = calc_dX_conv(
187+
self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
188+
anti_padding=self.antiPad
189+
)
190+
self.dInputs.set(dInputs)
191+
192+
@compilable
193+
def reset(self): # in_shape, out_shape):
194+
preVals = jnp.zeros(self.in_shape.get())
195+
postVals = jnp.zeros(self.out_shape.get())
196+
self.inputs.set(preVals)
197+
self.outputs.set(postVals)
198+
self.preSpike.set(preVals)
199+
self.postSpike.set(postVals)
200+
self.preTrace.set(preVals)
201+
self.postTrace.set(postVals)
199202

200203
@classmethod
201204
def help(cls): ## component help function

0 commit comments

Comments
 (0)