1414 * See the License for the specific language governing permissions and
1515 * limitations under the License.
1616 */
17- package org .aika ;
17+ package network .aika ;
1818
19- import org .aika .lattice .AndNode ;
20- import org .aika .lattice .InputNode ;
21- import org .aika .lattice .Node ;
22- import org .aika .lattice . OrNode ;
23- import org .aika .neuron .INeuron ;
24- import org .aika .neuron .Synapse ;
19+ import network .aika .lattice .AndNode ;
20+ import network .aika .lattice .Node ;
21+ import network .aika .lattice .OrNode ;
22+ import network .aika .neuron . INeuron ;
23+ import network .aika .neuron .relation . Relation ;
24+ import network .aika .neuron .Synapse ;
2525
2626import java .util .*;
2727
28+ import static network .aika .neuron .activation .Range .Mapping .NONE ;
29+
2830/**
2931 * Converts the synapse weights of a neuron into a boolean logic representation of this neuron.
3032 *
3133 * @author Lukas Molzberger
3234 */
3335public class Converter {
3436
35- public static int MAX_AND_NODE_SIZE = 4 ;
37+ public static int MAX_AND_NODE_SIZE = 6 ;
3638
3739
3840 public static Comparator <Synapse > SYNAPSE_COMP = (s1 , s2 ) -> {
39- int r = Double .compare (s2 .weight , s1 .weight );
41+ int r = Boolean .compare (
42+ s2 .key .rangeOutput .begin != NONE || s2 .key .rangeOutput .end != NONE || s2 .key .identity ,
43+ s1 .key .rangeOutput .begin != NONE || s1 .key .rangeOutput .end != NONE || s1 .key .identity
44+ );
45+ if (r != 0 ) return r ;
46+ r = Double .compare (s2 .weight , s1 .weight );
4047 if (r != 0 ) return r ;
41- return Synapse . INPUT_SYNAPSE_COMP . compare (s1 , s2 );
48+ return Integer . compare (s1 . id , s2 . id );
4249 };
4350
44- private Model model ;
4551 private int threadId ;
4652 private INeuron neuron ;
4753 private Document doc ;
4854 private OrNode outputNode ;
4955 private Collection <Synapse > modifiedSynapses ;
5056
5157
52- public static boolean convert (Model m , int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
53- return new Converter (m , threadId , doc , neuron , modifiedSynapses ).convert ();
58+ public static boolean convert (int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
59+ return new Converter (threadId , doc , neuron , modifiedSynapses ).convert ();
5460 }
5561
5662
57- private Converter (Model model , int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
58- this .model = model ;
63+ private Converter (int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
5964 this .doc = doc ;
6065 this .neuron = neuron ;
6166 this .threadId = threadId ;
@@ -70,19 +75,13 @@ private boolean convert() {
7075
7176 if (neuron .biasSum + neuron .posDirSum + neuron .posRecSum <= 0.0 ) {
7277 neuron .requiredSum = neuron .posDirSum + neuron .posRecSum ;
73- outputNode .removeParents (threadId , false );
78+ outputNode .removeParents (threadId );
7479 return false ;
7580 }
7681
77- TreeSet <Synapse > tmp = new TreeSet <>(SYNAPSE_COMP );
78- for (Synapse s : neuron .inputSynapses .values ()) {
79- if (!s .isNegative () && !s .key .isRecurrent && !s .inactive ) {
80- tmp .add (s );
81- }
82- }
82+ List <Synapse > candidates = prepareCandidates ();
8383
84- Integer offset = null ;
85- Node requiredNode = null ;
84+ NodeContext nodeContext = null ;
8685 boolean noFurtherRefinement = false ;
8786 TreeSet <Synapse > reqSyns = new TreeSet <>(Synapse .INPUT_SYNAPSE_COMP );
8887 double sum = 0.0 ;
@@ -91,7 +90,7 @@ private boolean convert() {
9190 if (neuron .numDisjunctiveSynapses == 0 ) {
9291 double remainingSum = neuron .posDirSum ;
9392 int i = 0 ;
94- for (Synapse s : tmp ) {
93+ for (Synapse s : candidates ) {
9594 final boolean isOptionalInput = sum + remainingSum - s .weight + neuron .posRecSum + neuron .biasSum > 0.0 ;
9695 final boolean maxAndNodesReached = i >= MAX_AND_NODE_SIZE ;
9796 if (isOptionalInput || maxAndNodesReached ) {
@@ -102,8 +101,11 @@ private boolean convert() {
102101 neuron .requiredSum += s .weight ;
103102 reqSyns .add (s );
104103
105- requiredNode = getNextLevelNode (offset , requiredNode , s );
106- offset = Utils .nullSafeMin (s .key .relativeRid , offset );
104+ NodeContext nlNodeContext = expandNode (nodeContext , s );
105+ if (nlNodeContext == null ) {
106+ break ;
107+ }
108+ nodeContext = nlNodeContext ;
107109
108110 i ++;
109111
@@ -114,45 +116,76 @@ private boolean convert() {
114116 noFurtherRefinement = true ;
115117 break ;
116118 }
117- }
118119
119- outputNode .removeParents (threadId , false );
120- if (requiredNode != outputNode .requiredNode ) {
121- outputNode .requiredNode = requiredNode ;
122120 }
123121
122+ outputNode .removeParents (threadId );
123+
124124 if (noFurtherRefinement || i == MAX_AND_NODE_SIZE ) {
125- outputNode .addInput (offset , threadId , requiredNode , false );
125+ outputNode .addInput (nodeContext . getSynapseIds () , threadId , nodeContext . node );
126126 } else {
127- for (Synapse s : tmp ) {
127+ for (Synapse s : candidates ) {
128128 boolean belowThreshold = sum + s .weight + remainingSum + neuron .posRecSum + neuron .biasSum <= 0.0 ;
129129 if (belowThreshold ) {
130130 break ;
131131 }
132132
133133 if (!reqSyns .contains (s )) {
134- Node nln ;
135- nln = getNextLevelNode (offset , requiredNode , s );
136-
137- Integer nOffset = Utils .nullSafeMin (s .key .relativeRid , offset );
138- outputNode .addInput (nOffset , threadId , nln , false );
139- remainingSum -= s .weight ;
134+ NodeContext nlNodeContext = expandNode (nodeContext , s );
135+ if (nlNodeContext != null ) {
136+ outputNode .addInput (nlNodeContext .getSynapseIds (), threadId , nlNodeContext .node );
137+ remainingSum -= s .weight ;
138+ }
140139 }
141140 }
142141 }
143142 } else {
144143 for (Synapse s : modifiedSynapses ) {
145144 if (s .weight + neuron .posRecSum + neuron .biasSum > 0.0 ) {
146- Node nln = s .inputNode .get ();
147- offset = s .key .relativeRid ;
148- outputNode .addInput (offset , threadId , nln , false );
145+ NodeContext nlNodeContext = expandNode (nodeContext , s );
146+ outputNode .addInput (nlNodeContext .getSynapseIds (), threadId , nlNodeContext .node );
149147 }
150148 }
151149 }
152150
153151 return true ;
154152 }
155153
154+ private List <Synapse > prepareCandidates () {
155+ Synapse syn = getBestSynapse (neuron .inputSynapses .values ());
156+
157+ TreeSet <Integer > alreadyCollected = new TreeSet <>();
158+ ArrayList <Synapse > selectedCandidates = new ArrayList <>();
159+ TreeMap <Integer , Synapse > relatedSyns = new TreeMap <>();
160+ while (syn != null && selectedCandidates .size () < MAX_AND_NODE_SIZE ) {
161+ relatedSyns .remove (syn .id );
162+ selectedCandidates .add (syn );
163+ alreadyCollected .add (syn .id );
164+ for (Integer synId : syn .relations .keySet ()) {
165+ if (!alreadyCollected .contains (synId )) {
166+ relatedSyns .put (synId , syn .output .getSynapseById (synId ));
167+ }
168+ }
169+
170+ syn = getBestSynapse (relatedSyns .values ());
171+ }
172+
173+ return selectedCandidates ;
174+ }
175+
176+
177+ private Synapse getBestSynapse (Collection <Synapse > synapses ) {
178+ Synapse maxSyn = null ;
179+ for (Synapse s : synapses ) {
180+ if (!s .isNegative () && !s .key .isRecurrent && !s .inactive ) {
181+ if (maxSyn == null || SYNAPSE_COMP .compare (maxSyn , s ) > 0 ) {
182+ maxSyn = s ;
183+ }
184+ }
185+ }
186+ return maxSyn ;
187+ }
188+
156189
157190 public static final int DIRECT = 0 ;
158191 public static final int RECURRENT = 1 ;
@@ -171,14 +204,6 @@ private void initInputNodesAndComputeWeightSums() {
171204 INeuron in = s .input .get ();
172205 in .lock .acquireWriteLock ();
173206 try {
174- if (s .inputNode == null ) {
175- InputNode iNode = InputNode .add (model , s .key .createInputNodeKey (), s .input .get ());
176- iNode .setModified ();
177- iNode .setSynapse (s );
178- iNode .postCreate (doc );
179- s .inputNode = iNode .provider ;
180- }
181-
182207 if (!s .inactive ) {
183208 sumDelta [s .key .isRecurrent ? RECURRENT : DIRECT ][s .isNegative () ? NEGATIVE : POSITIVE ] -= s .weight ;
184209 sumDelta [s .key .isRecurrent ? RECURRENT : DIRECT ][s .getNewWeight () <= 0.0 ? NEGATIVE : POSITIVE ] += s .getNewWeight ();
@@ -225,13 +250,53 @@ private void initInputNodesAndComputeWeightSums() {
225250 }
226251
227252
228- private Node getNextLevelNode (Integer offset , Node requiredNode , Synapse s ) {
229- Node nln ;
230- if (requiredNode == null ) {
231- nln = s .inputNode .get ();
253+ private NodeContext expandNode (NodeContext nc , Synapse s ) {
254+ if (nc == null ) {
255+ NodeContext nln = new NodeContext ();
256+ nln .node = s .input .get ().outputNode .get ();
257+ nln .offsets = new Synapse [] {s };
258+ return nln ;
232259 } else {
233- nln = AndNode .createNextLevelNode (model , threadId , doc , requiredNode , new AndNode .Refinement (s .key .relativeRid , offset , s .inputNode ), null );
260+ Relation [] relations = new Relation [nc .offsets .length ];
261+ for (int i = 0 ; i < nc .offsets .length ; i ++) {
262+ Synapse linkedSynapse = nc .offsets [i ];
263+ relations [i ] = s .relations .get (linkedSynapse .id );
264+ }
265+
266+ NodeContext nln = new NodeContext ();
267+ nln .offsets = new Synapse [nc .offsets .length + 1 ];
268+ AndNode .Refinement ref = new AndNode .Refinement (new AndNode .RelationsMap (relations ), s .input .get ().outputNode );
269+ AndNode .RefValue rv = nc .node .extend (threadId , doc , ref );
270+ if (rv == null ) {
271+ return null ;
272+ }
273+
274+ nln .node = rv .child .get (doc );
275+
276+ for (int i = 0 ; i < nc .offsets .length ; i ++) {
277+ nln .offsets [rv .offsets [i ]] = nc .offsets [i ];
278+ }
279+ for (int i = 0 ; i < nln .offsets .length ; i ++) {
280+ if (nln .offsets [i ] == null ) {
281+ nln .offsets [i ] = s ;
282+ }
283+ }
284+ return nln ;
285+ }
286+ }
287+
288+
289+ private class NodeContext {
290+ Node node ;
291+
292+ Synapse [] offsets ;
293+
294+ int [] getSynapseIds () {
295+ int [] result = new int [offsets .length ];
296+ for (int i = 0 ; i < result .length ; i ++) {
297+ result [i ] = offsets [i ].id ;
298+ }
299+ return result ;
234300 }
235- return nln ;
236301 }
237302}
0 commit comments