Skip to content

Commit a9831a8

Browse files
Add _initialize_with_initializer to PyTorch, NumPy, and OpenVINO backends
1 parent 951b5a2 commit a9831a8

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

keras/src/backend/numpy/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ class Variable(KerasVariable):
2323
def _initialize(self, value):
2424
self._value = value
2525

26+
def _initialize_with_initializer(self, initializer):
27+
value = self._convert_to_tensor(
28+
initializer(self._shape, dtype=self._dtype)
29+
)
30+
self._initialize(value)
31+
2632
def _direct_assign(self, value):
2733
self._value = np.array(value, dtype=self._dtype)
2834

keras/src/backend/openvino/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,12 @@ def _initialize(self, value):
572572
)
573573
self._value = OpenVINOKerasTensor(value_const.output(0))
574574

575+
def _initialize_with_initializer(self, initializer):
576+
value = self._convert_to_tensor(
577+
initializer(self._shape, dtype=self._dtype)
578+
)
579+
self._initialize(value)
580+
575581
def _direct_assign(self, value):
576582
self._value = value
577583

keras/src/backend/torch/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def _initialize(self, value):
109109
requires_grad=self.trainable,
110110
).to(get_device())
111111

112+
def _initialize_with_initializer(self, initializer):
113+
value = self._convert_to_tensor(
114+
initializer(self._shape, dtype=self._dtype)
115+
)
116+
self._initialize(value)
117+
112118
def _direct_assign(self, value):
113119
with torch.no_grad():
114120
self.value.copy_(value)

0 commit comments

Comments
 (0)