Skip to content

Commit 92140b7

Browse files
FieldDefinition bindings
1 parent a4ef26b commit 92140b7

File tree

10 files changed

+96
-29
lines changed

10 files changed

+96
-29
lines changed

include/fields/abstract_function_definition.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class AbstractFunctionDefinition : public FieldDefinition {
1010
public:
1111
// Constructors
12-
AbstractFunctionDefinition(Type* objectType, const std::string& name, int* numArgs, double tolerance);
12+
AbstractFunctionDefinition(Type* objectType, const std::string& name, int numArgs, double tolerance);
1313

1414
// Destructor
1515
virtual ~AbstractFunctionDefinition() = default;

include/fields/field_definition.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ class FieldDefinition {
3535

3636
public:
3737

38-
FieldDefinition(Type* objectType, const std::string& name, int* numArgs, double tolerance);
38+
FieldDefinition(Type* objectType, const std::string& name, int numArgs, double tolerance);
39+
virtual ~FieldDefinition() = default;
3940

4041
void setFieldId(int fieldId);
41-
virtual void transmit(Field* targetField, FieldLinkDefinition* fieldLink, double update);
42+
virtual void transmit(Field* targetField, FieldLinkDefinition* fieldLink, double update) = 0;
4243
void receiveUpdate(Field* field, double update);
4344

4445
FieldDefinition* getParent() const;
@@ -54,6 +55,7 @@ class FieldDefinition {
5455
void addOutput(FieldLinkDefinition* fl);
5556
std::vector<FieldLinkDefinition*> getOutputs();
5657

58+
FieldDefinition& in(Relation* relation, FieldDefinition* input, int arg);
5759
FieldDefinition& out(Relation* relation, FieldDefinition* output, int arg);
5860

5961
FieldDefinition& setName(const std::string& name);

include/fields/input_field.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "fields/obj.h"
99

1010
class InputField : public AbstractFunctionDefinition {
11-
static int zero;
1211

1312
public:
1413

include/fields/subtraction.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef SUBTRACTION_H
2+
#define SUBTRACTION_H
3+
4+
#include "fields/abstract_function_definition.h"
5+
#include "fields/type.h"
6+
#include "fields/obj.h"
7+
8+
9+
class Subtraction : public AbstractFunctionDefinition {
10+
public:
11+
12+
// Constructor
13+
Subtraction(Type* ref, const std::string& name);
14+
15+
// Overridden method from AbstractFunctionDefinition
16+
double computeUpdate(Obj* obj, FieldLinkDefinition* fl, double u) override;
17+
};
18+
19+
#endif // SUBTRACTION_H

src/fields/abstract_function_definition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
// Constructor with 4 arguments
5-
AbstractFunctionDefinition::AbstractFunctionDefinition(Type* objectType, const std::string& name, int* numArgs, double tolerance)
5+
AbstractFunctionDefinition::AbstractFunctionDefinition(Type* objectType, const std::string& name, int numArgs, double tolerance)
66
: FieldDefinition(objectType, name, numArgs, tolerance) {}
77

88
// Transmit method

src/fields/field_definition.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
#include "fields/obj.h"
99

1010

11-
FieldDefinition::FieldDefinition(Type* objectType, const std::string& name, int* numArgs, double tolerance)
11+
FieldDefinition::FieldDefinition(Type* objectType, const std::string& name, int numArgs, double tolerance)
1212
: objectType(objectType), name(name), isNextRound(false) {
1313
this->tolerance = tolerance;
1414
objectType->setFieldDefinition(this);
1515

16-
if (numArgs != nullptr) {
17-
inputs.reserve(*numArgs);
16+
if (numArgs > 0) {
17+
inputs.reserve(numArgs);
1818
}
1919
}
2020

@@ -81,6 +81,12 @@ std::vector<FieldLinkDefinition*> FieldDefinition::getOutputs() {
8181
return outputs;
8282
}
8383

84+
FieldDefinition& FieldDefinition::in(Relation* relation, FieldDefinition* input, int arg) {
85+
FieldLinkDefinition::link(input, this, relation, arg);
86+
// assert(relation || objectType->isInstanceOf(output->getObjectType()) || output->getObjectType()->isInstanceOf(objectType));
87+
return *this;
88+
}
89+
8490
FieldDefinition& FieldDefinition::out(Relation* relation, FieldDefinition* output, int arg) {
8591
FieldLinkDefinition::link(this, output, relation, arg);
8692
// assert(relation || objectType->isInstanceOf(output->getObjectType()) || output->getObjectType()->isInstanceOf(objectType));

src/fields/input_field.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
// Constructor
55
InputField::InputField(Type* ref, const std::string& name)
6-
: AbstractFunctionDefinition(ref, name, &zero, 0.0) {}
6+
: AbstractFunctionDefinition(ref, name, 0, 0.0) {}
77

88
// Overridden computeUpdate method
99
double InputField::computeUpdate(Obj* obj, FieldLinkDefinition* fl, double u) {

src/fields/subtraction.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include "fields/subtraction.h"
2+
3+
4+
// Constructor
5+
Subtraction::Subtraction(Type* ref, const std::string& name)
6+
: AbstractFunctionDefinition(ref, name, 0, 0.0) {}
7+
8+
// Overridden computeUpdate method
9+
double Subtraction::computeUpdate(Obj* obj, FieldLinkDefinition* fl, double u) {
10+
return (fl->getArgument() == 0) ? u : -u;
11+
}
Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include <pybind11/pybind11.h>
22

3+
#include "fields/field_definition.h"
34
#include "fields/field_update.h"
45
#include "fields/type.h"
56
#include "fields/type_registry.h"
67
#include "fields/input_field.h"
8+
#include "fields/subtraction.h"
79

810

911
// ----------------
@@ -14,23 +16,51 @@ namespace py = pybind11;
1416

1517
PYBIND11_MODULE(aika, m)
1618
{
17-
py::class_<FieldUpdate>(m, "FieldUpdate")
18-
.def(py::init<ProcessingPhase&, QueueInterceptor*>());
19-
20-
py::class_<Type>(m, "Type")
21-
.def(py::init<TypeRegistry*, const std::string&>())
22-
.def("__str__", [](const Type &t) {
23-
return t.toString();
24-
})
25-
.def("inputField", [](const Type &ref, const std::string &name) {
26-
return new InputField(
27-
const_cast<Type*>(&ref),
28-
name
29-
);
30-
});
31-
32-
py::class_<TypeRegistry>(m, "TypeRegistry")
33-
.def(py::init<>())
34-
.def("getType", &TypeRegistry::getType)
35-
.def("registerType", &TypeRegistry::registerType);
19+
// Bind Relation
20+
py::class_<Relation>(m, "Relation");
21+
22+
py::class_<FieldUpdate>(m, "FieldUpdate")
23+
.def(py::init<ProcessingPhase&, QueueInterceptor*>());
24+
25+
// Bind FieldDefinition first
26+
py::class_<FieldDefinition>(m, "FieldDefinition")
27+
.def("in", &FieldDefinition::in, py::return_value_policy::reference_internal,
28+
py::arg("relation"), py::arg("input"), py::arg("arg"))
29+
.def("out", &FieldDefinition::out, py::return_value_policy::reference_internal,
30+
py::arg("relation"), py::arg("output"), py::arg("arg"));
31+
32+
// Bind AbstractFunctionDefinition (inherits from FieldDefinition)
33+
py::class_<AbstractFunctionDefinition, FieldDefinition>(m, "AbstractFunctionDefinition");
34+
35+
// Bind Subtraction (inherits from AbstractFunctionDefinition)
36+
py::class_<Subtraction, AbstractFunctionDefinition>(m, "Subtraction");
37+
38+
py::class_<InputField>(m, "InputField")
39+
.def(py::init<Type*, const std::string &>())
40+
.def("__str__", [](const InputField &f) {
41+
return f.toString();
42+
});
43+
44+
py::class_<Type>(m, "Type")
45+
.def(py::init<TypeRegistry*, const std::string&>())
46+
.def("__str__", [](const Type &t) {
47+
return t.toString();
48+
})
49+
.def("inputField", [](const Type &ref, const std::string &name) {
50+
return new InputField(
51+
const_cast<Type*>(&ref),
52+
name
53+
);
54+
}, py::return_value_policy::take_ownership)
55+
.def("sub", [](const Type &ref, const std::string &name) {
56+
return new Subtraction(
57+
const_cast<Type*>(&ref),
58+
name
59+
);
60+
}, py::return_value_policy::take_ownership);
61+
62+
py::class_<TypeRegistry>(m, "TypeRegistry")
63+
.def(py::init<>())
64+
.def("getType", &TypeRegistry::getType)
65+
.def("registerType", &TypeRegistry::registerType);
3666
}

tests/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77

88
tr = aika.TypeRegistry()
99

10-
t =aika.Type(tr)
10+
t = aika.Type(tr)
1111

0 commit comments

Comments
 (0)