
Commit 10ef0e0

rxng8 and ago109 authored
Refactoring neuronal and synaptic components (#123) - merge from fork to v3
* refactoring graded cells
* update refactored models
* update sLIF cell

---------

Co-authored-by: Alex Ororbia <agocse109@gmail.com>
1 parent 6cb319a commit 10ef0e0
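The common thread across the diffs below is a move away from `@transition`-decorated static methods (which received every compartment value and constant as a positional argument and returned a tuple of updated values) toward `@compilable` instance methods that read compartments with `.get()` and write them back with `.set()`, pulling constants off `self`. The snippet that follows is a minimal, self-contained sketch of that calling convention only; the `Compartment` class and `compilable` decorator defined here are stand-ins so the example runs without ngcsimlib installed, not the library's actual implementations.

```python
# Mock stand-ins so this sketch runs on its own; in ngc-learn these come from
# ngcsimlib.compartment.Compartment and ngcsimlib.parser.compilable (see the diffs).
class Compartment:
    def __init__(self, value):
        self._value = value
    def get(self):
        return self._value
    def set(self, value):
        self._value = value

def compilable(fn):  # placeholder decorator; the real one marks methods for compilation
    return fn

class ToyCell:
    def __init__(self, gain=2.0):
        self.gain = gain              # plain constant, now read via self
        self.j = Compartment(0.0)     # input compartment
        self.z = Compartment(0.0)     # output compartment

    # Old style (removed in this commit), shown for contrast:
    #   @transition(output_compartments=["z"])
    #   @staticmethod
    #   def advance_state(gain, j, z):
    #       return j * gain
    @compilable
    def advance_state(self, dt):
        j = self.j.get()              # read compartment values
        z = j * self.gain             # compute the update
        self.z.set(z)                 # write results back instead of returning a tuple

cell = ToyCell()
cell.j.set(1.5)
cell.advance_state(dt=1.0)
print(cell.z.get())  # 3.0
```

Reading constants from `self` and compartments via `.get()`/`.set()` removes the long positional argument lists and output-compartment declarations the old static `advance_state`/`reset` methods required.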

File tree

11 files changed: +316, −271 lines


ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 44 additions & 18 deletions
@@ -1,9 +1,13 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
 from ngclearn.utils.model_utils import sigmoid, d_sigmoid
-from ngcsimlib.compilers.process import transition
+
+from ngcsimlib.logger import info
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -59,14 +63,20 @@ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None,
         self.modulator = Compartment(restVals + 1.0) # to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
-    @transition(output_compartments=["dp", "dtarget", "L", "mask"])
-    @staticmethod
-    def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
+    # @transition(output_compartments=["dp", "dtarget", "L", "mask"])
+    @compilable
+    def advance_state(self, dt): ## compute Bernoulli error cell output
+        # Get the variables
+        p = self.p.get()
+        target = self.target.get()
+        modulator = self.modulator.get()
+        mask = self.mask.get()
+
         # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional
         eps = 0.0001
         _p = p
-        if input_logits: ## convert from "logits" to probs via sigmoidal link function
+        if self.input_logits: ## convert from "logits" to probs via sigmoidal link function
             _p = sigmoid(p)
         _p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent div by 0
         x = target
@@ -78,7 +88,7 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
         log_p = jnp.log(_p) ## ln(p)
         log_one_min_p = jnp.log(one_min_p) ## ln(1 - p)
         L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bern LL
-        if input_logits:
+        if self.input_logits:
             dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p)
         else:
             dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp
@@ -89,14 +99,21 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
         dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
         dtarget = dL_dx * modulator * mask
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
-        return dp, dtarget, jnp.squeeze(L), mask
-
-    @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
-    @staticmethod
-    def reset(batch_size, shape): ## reset core components/statistics
-        _shape = (batch_size, shape[0])
-        if len(shape) > 1:
-            _shape = (batch_size, shape[0], shape[1], shape[2])
+
+        # Set state
+        # dp, dtarget, jnp.squeeze(L), mask
+        self.dp.set(dp)
+        self.dtarget.set(dtarget)
+        self.L.set(jnp.squeeze(L))
+        self.mask.set(mask)
+
+
+    # @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
+    @compilable
+    def reset(self, batch_size): ## reset core components/statistics
+        _shape = (batch_size, self.shape[0])
+        if len(self.shape) > 1:
+            _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
         restVals = jnp.zeros(_shape) ## "rest"/reset values
         dp = restVals
         dtarget = restVals
@@ -105,7 +122,16 @@ def reset(batch_size, shape): ## reset core components/statistics
         modulator = restVals + 1. ## reset modulator signal
         L = 0. #jnp.zeros((1, 1)) ## rest loss
         mask = jnp.ones(_shape) ## reset mask
-        return dp, dtarget, target, p, modulator, L, mask
+
+        # Set compartment
+        self.dp.set(dp)
+        self.dtarget.set(dtarget)
+        self.target.set(target)
+        self.p.set(p)
+        self.modulator.set(modulator)
+        self.L.set(L)
+        self.mask.set(mask)
+
 
     @classmethod
     def help(cls): ## component help function
@@ -136,11 +162,11 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
             if stats is not None:
                 line = [f"{k}: {v}" for k, v in stats.items()]
                 line = ", ".join(line)

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 23 additions & 24 deletions
@@ -1,8 +1,12 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
+
+from ngcsimlib.logger import info
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -71,9 +75,15 @@ def eval_log_density(target, mu, Sigma):
         log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
         return log_density
 
-    @transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
-    @staticmethod
-    def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
+    @compilable
+    def advance_state(self, dt): ## compute Gaussian error cell output
+        # Get the variables
+        mu = self.mu.get()
+        target = self.target.get()
+        Sigma = self.Sigma.get()
+        modulator = self.modulator.get()
+        mask = self.mask.get()
+
         # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional:
         # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
@@ -90,24 +100,13 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e
         dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
         dtarget = dtarget * modulator * mask
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
-        return dmu, dtarget, dSigma, jnp.squeeze(L), mask
 
-    @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
-    @staticmethod
-    def reset(batch_size, shape, sigma_shape): ## reset core components/statistics
-        _shape = (batch_size, shape[0])
-        if len(shape) > 1:
-            _shape = (batch_size, shape[0], shape[1], shape[2])
-        restVals = jnp.zeros(_shape)
-        dmu = restVals
-        dtarget = restVals
-        dSigma = jnp.zeros(sigma_shape)
-        target = restVals
-        mu = restVals
-        modulator = mu + 1.
-        L = 0. #jnp.zeros((1, 1))
-        mask = jnp.ones(_shape)
-        return dmu, dtarget, dSigma, target, mu, modulator, L, mask
+        # Update compartments
+        self.dmu.set(dmu)
+        self.dtarget.set(dtarget)
+        self.dSigma.set(dSigma)
+        self.L.set(jnp.squeeze(L))
+        self.mask.set(mask)
 
     @classmethod
     def help(cls): ## component help function
@@ -140,11 +139,11 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
             if stats is not None:
                 line = [f"{k}: {v}" for k, v in stats.items()]
                 line = ", ".join(line)

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 25 additions & 23 deletions
@@ -1,8 +1,12 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
+
+from ngcsimlib.logger import info
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -44,7 +48,7 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
         else:
             _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
         scale_shape = (1, 1)
-        if not isinstance(scale, float) and not isinstance(sigma, int):
+        if not isinstance(scale, float) and not isinstance(scale, int):
             scale_shape = jnp.array(scale).shape
         self.scale_shape = scale_shape
         ## Layer Size setup
@@ -67,9 +71,15 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
         self.modulator = Compartment(restVals + 1.0) ## to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
-    @transition(output_compartments=["dshift", "dtarget", "dScale", "L", "mask"])
-    @staticmethod
-    def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplacian error cell output
+    @compilable
+    def advance_state(self, dt): ## compute Laplacian error cell output
+        # Get the variables
+        shift = self.shift.get()
+        target = self.target.get()
+        Scale = self.Scale.get()
+        modulator = self.modulator.get()
+        mask = self.mask.get()
+
         # Moves Laplacian cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional:
         # FIXME: Currently, below does: L(targ, shift) = -||targ - shift||_1/scale
@@ -85,21 +95,13 @@ def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplaci
         dshift = dshift * modulator * mask
         dtarget = dtarget * modulator * mask
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
-        return dshift, dtarget, dScale, jnp.squeeze(L), mask
-
-    @transition(output_compartments=["dshift", "dtarget", "dScale", "target", "shift", "modulator", "L", "mask"])
-    @staticmethod
-    def reset(batch_size, n_units, scale_shape):
-        restVals = jnp.zeros((batch_size, n_units))
-        dshift = restVals
-        dtarget = restVals
-        dScale = jnp.zeros(scale_shape)
-        target = restVals
-        shift = restVals
-        modulator = shift + 1.
-        L = 0.
-        mask = jnp.ones((batch_size, n_units))
-        return dshift, dtarget, dScale, target, shift, modulator, L, mask
+
+        # Update compartments
+        self.dshift.set(dshift)
+        self.dtarget.set(dtarget)
+        self.dScale.set(dScale)
+        self.L.set(jnp.squeeze(L))
+        self.mask.set(mask)
 
     @classmethod
     def help(cls): ## component help function
@@ -131,11 +133,11 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 28 additions & 29 deletions
@@ -11,6 +11,9 @@
 from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
     step_euler, step_rk2, step_rk4
 
+from ngcsimlib.logger import info
+from ngcsimlib.parser import compilable
+
 def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
     z_leak = jnp.sign(z) ## d/dx of Laplace is signum
     dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
@@ -198,7 +201,6 @@ def __init__(
         self.n_units = n_units
         self.batch_size = batch_size
 
-
         omega_0 = None
         if act_fx == "sine":
             omega_0 = kwargs["omega_0"]
@@ -211,46 +213,43 @@ def __init__(
         self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
         self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
 
-    @transition(output_compartments=["j", "j_td", "z", "zF"])
-    @staticmethod
-    def advance_state(
-            dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful,
-            output_scale, j, j_td, z):
+    @compilable
+    def advance_state(self, dt):
+        # Get the compartment values
+        j = self.j.get()
+        j_td = self.j_td.get()
+        z = self.z.get()
+
         #if tau_m > 0.:
-        if is_stateful:
+        if self.is_stateful:
             ### run a step of integration over neuronal dynamics
             ## Notes:
             ## self.pressure <-- "top-down" expectation / contextual pressure
             ## self.current <-- "bottom-up" data-dependent signal
-            dfx_val = dfx(z)
+            dfx_val = self.dfx(z)
             j = _modulate(j, dfx_val)
-            j = j * resist_scale
+            j = j * self.resist_scale
             tmp_z = _run_cell(dt, j, j_td, z,
-                              tau_m, leak_gamma=priorLeakRate,
-                              integType=intgFlag, priorType=priorType)
+                              self.tau_m, leak_gamma=self.priorLeakRate,
+                              integType=self.intgFlag, priorType=self.priorType)
             ## apply optional thresholding sub-dynamics
-            if thresholdType == "soft_threshold":
-                tmp_z = threshold_soft(tmp_z, thr_lmbda)
-            elif thresholdType == "cauchy_threshold":
-                tmp_z = threshold_cauchy(tmp_z, thr_lmbda)
+            if self.thresholdType == "soft_threshold":
+                tmp_z = threshold_soft(tmp_z, self.thr_lmbda)
+            elif self.thresholdType == "cauchy_threshold":
+                tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda)
             z = tmp_z ## pre-activation function value(s)
-            zF = fx(z) * output_scale ## post-activation function value(s)
+            zF = self.fx(z) * self.output_scale ## post-activation function value(s)
         else:
             ## run in "stateless" mode (when no membrane time constant provided)
             j_total = j + j_td
             z = _run_cell_stateless(j_total)
-            zF = fx(z) * output_scale
-        return j, j_td, z, zF
-
-    @transition(output_compartments=["j", "j_td", "z", "zF"])
-    @staticmethod
-    def reset(batch_size, shape): #n_units
-        _shape = (batch_size, shape[0])
-        if len(shape) > 1:
-            _shape = (batch_size, shape[0], shape[1], shape[2])
-        restVals = jnp.zeros(_shape)
-        return tuple([restVals for _ in range(4)])
+            zF = self.fx(z) * self.output_scale
 
+        # Update compartments
+        self.j.set(j)
+        self.j_td.set(j_td)
+        self.z.set(z)
+        self.zF.set(zF)
 
     def save(self, directory, **kwargs):
         ## do a protected save of constants, depending on whether they are floats or arrays
@@ -308,11 +307,11 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
             if stats is not None:
                 line = [f"{k}: {v}" for k, v in stats.items()]
                 line = ", ".join(line)
