|
1 | 1 | __all__ = [ |
2 | 2 | "constant", |
| 3 | + "zeros", |
| 4 | + "ones", |
3 | 5 | "diag", |
4 | 6 | "identity", |
5 | 7 | "iota", |
|
17 | 19 | "flip", |
18 | 20 | "join", |
19 | 21 | "moddims", |
| 22 | + "reshape", |
20 | 23 | "reorder", |
21 | 24 | "replace", |
22 | 25 | "select", |
@@ -70,6 +73,59 @@ def constant(scalar: int | float | complex, shape: tuple[int, ...] = (1,), dtype |
70 | 73 | """ |
71 | 74 | return cast(Array, wrapper.create_constant_array(scalar, shape, dtype)) |
72 | 75 |
|
| 76 | +def zeros(shape: tuple[int, ...], dtype: Dtype = float32) -> Array: |
| 77 | + """ |
| 78 | + Create a multi-dimensional array filled with zeros |
| 79 | +
|
| 80 | + Parameters |
| 81 | + ---------- |
| 82 | + shape : tuple[int, ...], optional, default: (1,) |
| 83 | + The shape of the constant array. |
| 84 | +
|
| 85 | + dtype : Dtype, optional, default: float32 |
| 86 | + Data type of the array. |
| 87 | +
|
| 88 | + Returns |
| 89 | + ------- |
| 90 | + Array |
| 91 | + A multi-dimensional ArrayFire array filled zeros |
| 92 | +
|
| 93 | + Notes |
| 94 | + ----- |
| 95 | + The shape parameter determines the dimensions of the resulting array: |
| 96 | + - If shape is (x1,), the output is a 1D array of size (x1,). |
| 97 | + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). |
| 98 | + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). |
| 99 | + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). |
| 100 | + """ |
| 101 | + return constant(0, shape, dtype) |
| 102 | + |
| 103 | +def ones(shape: tuple[int, ...], dtype: Dtype = float32) -> Array: |
| 104 | + """ |
| 105 | + Create a multi-dimensional array filled with ones |
| 106 | +
|
| 107 | + Parameters |
| 108 | + ---------- |
| 109 | + shape : tuple[int, ...], optional, default: (1,) |
| 110 | + The shape of the constant array. |
| 111 | +
|
| 112 | + dtype : Dtype, optional, default: float32 |
| 113 | + Data type of the array. |
| 114 | +
|
| 115 | + Returns |
| 116 | + ------- |
| 117 | + Array |
| 118 | + A multi-dimensional ArrayFire array filled ones |
| 119 | +
|
| 120 | + Notes |
| 121 | + ----- |
| 122 | + The shape parameter determines the dimensions of the resulting array: |
| 123 | + - If shape is (x1,), the output is a 1D array of size (x1,). |
| 124 | + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). |
| 125 | + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). |
| 126 | + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). |
| 127 | + """ |
| 128 | + return constant(1, shape, dtype) |
73 | 129 |
|
74 | 130 | @afarray_as_array |
75 | 131 | def diag(array: Array, /, *, diag_index: int = 0, extract: bool = True) -> Array: |
@@ -255,8 +311,7 @@ def lower(array: Array, /, *, is_unit_diag: bool = False) -> Array: |
255 | 311 | Notes |
256 | 312 | ----- |
257 | 313 | - The function does not alter the elements above the main diagonal; it simply does not include them in the output. |
258 | | - - This function can be useful for mathematical operations that require lower triangular matrices, such as certain |
259 | | - types of matrix factorizations. |
| 314 | + - This function can be useful for mathematical operations that require lower triangular matrices, such as certain types of matrix factorizations. |
260 | 315 |
|
261 | 316 | Examples |
262 | 317 | -------- |
@@ -312,8 +367,7 @@ def upper(array: Array, /, *, is_unit_diag: bool = False) -> Array: |
312 | 367 | Notes |
313 | 368 | ----- |
314 | 369 | - The function does not alter the elements below the main diagonal; it simply does not include them in the output. |
315 | | - - This function can be useful for mathematical operations that require upper triangular matrices, such as certain |
316 | | - types of matrix factorizations. |
| 370 | + - This function can be useful for mathematical operations that require upper triangular matrices, such as certain types of matrix factorizations. |
317 | 371 |
|
318 | 372 | Examples |
319 | 373 | -------- |
@@ -818,6 +872,40 @@ def moddims(array: Array, shape: tuple[int, ...], /) -> Array: |
818 | 872 | # TODO add examples to doc |
819 | 873 | return cast(Array, wrapper.moddims(array.arr, shape)) |
820 | 874 |
|
| 875 | +def reshape(array: Array, shape: tuple[int, ...], /) -> Array: |
| 876 | + """ |
| 877 | + Modify the shape of the array without changing the data layout. |
| 878 | +
|
| 879 | + Parameters |
| 880 | + ---------- |
| 881 | + array : af.Array |
| 882 | + Multi-dimensional array to be reshaped. |
| 883 | +
|
| 884 | + shape : tuple of int |
| 885 | + The desired shape of the output array. It should be a tuple of integers |
| 886 | + representing the dimensions of the output array. The product of these |
| 887 | + dimensions must match the total number of elements in the input array. |
| 888 | +
|
| 889 | + Returns |
| 890 | + ------- |
| 891 | + out : af.Array |
| 892 | + - An array containing the same data as `array` with the specified shape. |
| 893 | + - The total number of elements in `array` must match the product of the |
| 894 | + dimensions specified in the `shape` tuple. |
| 895 | +
|
| 896 | + Raises |
| 897 | + ------ |
| 898 | + ValueError |
| 899 | + If the total number of elements in the input array does not match the |
| 900 | + product of the dimensions specified in the `shape` tuple. |
| 901 | +
|
| 902 | + Notes |
| 903 | + ----- |
| 904 | + This function modifies the shape of the input array without changing the |
| 905 | + data layout. The resulting array will have the same data, but with a |
| 906 | + different shape as specified by the `shape` parameter. |
| 907 | + """ |
| 908 | + return moddims(array, shape) |
821 | 909 |
|
822 | 910 | @afarray_as_array |
823 | 911 | def reorder(array: Array, /, *, shape: tuple[int, ...] = (1, 0, 2, 3)) -> Array: |
|
0 commit comments