|
16 | 16 | #include "fields/division.h" |
17 | 17 | #include "fields/exponential_function.h" |
18 | 18 | #include "fields/summation.h" |
| 19 | +#include "fields/field_activation_function.h" |
19 | 20 |
|
20 | 21 |
|
21 | 22 | // ---------------- |
@@ -68,6 +69,10 @@ PYBIND11_MODULE(aika, m) |
68 | 69 | // Bind Summation (inherits from AbstractFunctionDefinition) |
69 | 70 | py::class_<Summation, AbstractFunctionDefinition>(m, "Summation"); |
70 | 71 |
|
| 72 | + // Bind FieldActivationFunction (inherits from AbstractFunctionDefinition) |
| 73 | + py::class_<FieldActivationFunction, AbstractFunctionDefinition>(m, "FieldActivationFunction") |
| 74 | + .def(py::init<Type*, const std::string&, ActivationFunction*, double>()); |
| 75 | + |
71 | 76 | py::class_<InputField, FieldDefinition>(m, "InputField") |
72 | 77 | .def(py::init<Type*, const std::string &>()) |
73 | 78 | .def("__str__", [](const InputField &f) { |
@@ -120,6 +125,14 @@ PYBIND11_MODULE(aika, m) |
120 | 125 | const_cast<Type*>(&ref), |
121 | 126 | name |
122 | 127 | ); |
| 128 | + }, py::return_value_policy::reference_internal) |
| 129 | + .def("fieldActivationFunc", [](const Type &ref, const std::string &name, ActivationFunction* actFunction, double tolerance) { |
| 130 | + return new FieldActivationFunction( |
| 131 | + const_cast<Type*>(&ref), |
| 132 | + name, |
| 133 | + actFunction, |
| 134 | + tolerance |
| 135 | + ); |
123 | 136 | }, py::return_value_policy::reference_internal); |
124 | 137 |
|
125 | 138 | py::class_<Obj>(m, "Obj") |
|
0 commit comments