From 97ed17be242a8d7fed69813f36195bd3033b15d2 Mon Sep 17 00:00:00 2001 From: Jack Geraghty Date: Sat, 8 Nov 2025 16:35:39 +0000 Subject: [PATCH 1/3] Tensordot implementation [Issue #1517] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a full implementation of `tensordot` for n-dimensional arrays, supporting both numeric and paired axis specifications via the new `AxisSpec` enum. The design mirrors NumPy’s `tensordot` behaviour while integrating cleanly with ndarray’s trait-based approach for the ``dot`` product and existing shape and stride logic. All internal reshaping and permutation operations use `unwrap` and `expect` with explicit safety reasoning: each call is guarded by dimension and axis validation, ensuring panics can only occur under invalid `AxisSpec` input. Documentation and inline comments describe these invariants and the exact failure conditions. Includes tests verifying correct contraction for both paired and integer axis modes. --- src/linalg/impl_linalg.rs | 376 ++++++++++++++++++++++++++++++++++++++ src/linalg/mod.rs | 1 + 2 files changed, 377 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 14c82ff4d..000fcee39 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1140,3 +1140,379 @@ where A: LinalgScalar } } } + +/// Specifies the axes along which to perform a tensor contraction in [`tensordot`]. +/// +/// This enum defines how the axes of two tensors should be paired and reduced +/// during a generalized dot product. +/// +/// # Variants +/// +/// * `Num(usize)` — Contract over the last *n* axes of the left-hand tensor +/// and the first *n* axes of the right-hand tensor. +/// +/// * `Pair(Vec, Vec)` — Explicitly specify which axes from each +/// tensor to contract over. The first vector refers to axis indices in the +/// left-hand tensor, and the second to the right-hand tensor. +/// The two lists must be of equal length, and corresponding axes must have +/// matching dimension sizes. +/// +/// # Examples +/// +/// ``` +/// use ndarray::linalg::AxisSpec; +/// // Contract over one axis (e.g., last of `a`, first of `b`) +/// let axes = AxisSpec::Num(1); +/// +/// // Explicitly contract over multiple axes +/// let axes = AxisSpec::Pair(vec![1, 2], vec![0, 3]); +/// ``` +/// +/// # Notes +/// +/// - Axis indices can be negative, in which case they are interpreted +/// relative to the end of the tensor (e.g., `-1` refers to the last axis). +/// - The number and dimensionality of contracted axes determine the rank of +/// the result of [`tensordot`]. +/// - `AxisSpec` exists to disambiguate and formalise axis specifications, +/// avoiding confusion with [`ndarray::Axis`] and ['ndarray::iter::Axes']. +/// +/// # See also +/// +/// [`tensordot`] — Performs the generalized tensor contraction described by this specification. +/// [`ndarray::Axis`] — Represents a single axis index within an array. +#[derive(Clone, Debug)] +pub enum AxisSpec +{ + /// Contract over the last *n* axes of the left-hand tensor and the first *n* axes + /// of the right-hand tensor. + + /// For example, `Num(1)` performs standard matrix multiplication, + /// `Num(0)` performs an outer product, and `Num(2)` contracts over two axes. + /// + /// # Example + /// ``` + /// # use ndarray::linalg::AxisSpec; + /// let axes = AxisSpec::Num(1); // last of `a`, first of `b` + /// ``` + Num(usize), + /// Explicitly specify which axes of each tensor to contract over. + /// + /// The first vector lists the axes of the left-hand tensor `a`, + /// and the second vector lists the corresponding axes of the right-hand tensor `b`. + /// Both vectors must be the same length, and each corresponding axis pair + /// must have matching dimension sizes. + /// + /// Negative indices are supported and count from the end + /// (e.g. `-1` refers to the last axis). + /// + /// # Example + /// ``` + /// # use ndarray::linalg::AxisSpec; + /// let axes = AxisSpec::Pair(vec![1, -1], vec![0, 2]); + /// ``` + Pair(Vec, Vec), +} + +// Generalised tensor contraction. +/// +/// This operation extends `dot` and matrix multiplication +/// to tensors of arbitrary rank. The contraction pattern is +/// defined by [`AxisSpec`]. +pub trait Tensordot +{ + /// The result of the contraction. + type Output; + + /// Perform a tensor contraction along specified axes. + /// + /// Given two tensors `self` and `rhs` and an `AxisSpec` specification, + /// containing either a specific number of axes to contract over or + /// explicit lists of axes for each tensor. + /// + /// This function computes sum the products of the elements(components) of `self` and `Rhs` over the axes specified by the `axes` argument. + /// The AxisInfo argument can be a single non-negative integer scalar, N; if it is such, then the last N dimensions of `self` and the first N dimensions of `Rhs` are summed over. + /// If AxisInfo is a pair of lists of integers, then the first list contains the axes to be summed over in `self`, and the second list contains the axes to be summed over in `Rhs`. + /// + /// # Safety and Panics + /// + /// This function uses several internal `unwrap` and `expect` calls when + /// reshaping or permuting arrays. These operations are **guaranteed safe** + /// when the caller-provided `axes` specification is valid, because: + /// + /// - Each axis index is bounds-checked before use. + /// - Contracted axes are validated for duplicate indices and matching + /// dimension sizes. + /// - The resulting permutation and reshape patterns are internally consistent. + /// + /// The only circumstances under which an internal `unwrap`/`expect` may panic are: + /// + /// - The `axes` specification refers to out-of-range or duplicate axes. + /// - The contraction dimensions differ in size. + /// - The computed product of reshaped dimensions does not equal the + /// array’s total element count (which would indicate internal logic error). + #[track_caller] + fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output; +} + +/// Perform a tensor contraction along specified axes. +/// +/// See [`Tensordot::tensordot`] for more details. +#[track_caller] +pub fn tensordot(a: &ArrayBase, b: &ArrayBase, axes: AxisSpec) -> ArrayD +where + T: LinalgScalar, + Sa: Data, + Sb: Data, + Da: Dimension, + Db: Dimension, +{ + tensordot_impl::(a, b, axes) +} + +/// Performs the full contraction given resolved axis specification. +#[track_caller] +fn tensordot_impl(a: &ArrayBase, b: &ArrayBase, axes: AxisSpec) -> ArrayD +where + T: LinalgScalar, + Sa: Data, + Sb: Data, + Da: Dimension, + Db: Dimension, +{ + let nda = a.ndim() as isize; + let ndb = b.ndim() as isize; + + // Resolve axes + let (mut axes_a, mut axes_b): (Vec, Vec) = match axes { + AxisSpec::Num(n) => { + let n = n as isize; + assert!( + n <= nda && n <= ndb, + "tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}", + n, + nda, + ndb + ); + ((nda - n)..nda).zip(0..n).map(|(ia, ib)| (ia, ib)).unzip() + } + AxisSpec::Pair(aa, bb) => { + assert_eq!( + aa.len(), + bb.len(), + "tensordot: axes length mismatch (a has {}, b has {})", + aa.len(), + bb.len() + ); + (aa, bb) + } + }; + + // Normalise negative indices + for ax in &mut axes_a { + if *ax < 0 { + *ax += nda; + } + } + for ax in &mut axes_b { + if *ax < 0 { + *ax += ndb; + } + } + + // Validate + for &ax in &axes_a { + assert!( + (0..nda).contains(&ax), + "tensordot: axis {} out of bounds for a (ndim={})", + ax, + nda + ); + } + for &ax in &axes_b { + assert!( + (0..ndb).contains(&ax), + "tensordot: axis {} out of bounds for b (ndim={})", + ax, + ndb + ); + } + + // Shape checks + for (ia, ib) in axes_a.iter().zip(&axes_b) { + let da = a.shape()[*ia as usize]; + let db = b.shape()[*ib as usize]; + assert_eq!( + da, db, + "tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}", + ia, da, ib, db + ); + } + + // Determine non-contracted axes + let notin_a: Vec = (0..nda as usize) + .filter(|k| !axes_a.iter().any(|&ax| ax as usize == *k)) + .collect(); + let notin_b: Vec = (0..ndb as usize) + .filter(|k| !axes_b.iter().any(|&ax| ax as usize == *k)) + .collect(); + + // Reorder axes + let mut newaxes_a = notin_a.clone(); + newaxes_a.extend(axes_a.iter().map(|&x| x as usize)); + let mut newaxes_b = axes_b.iter().map(|&x| x as usize).collect::>(); + newaxes_b.extend(notin_b.iter().copied()); + + // Matrix shapes + let m = notin_a.iter().fold(1, |p, &ax| p * a.shape()[ax]); + let k = axes_a.iter().fold(1, |p, &ax| p * a.shape()[ax as usize]); + let n = notin_b.iter().fold(1, |p, &ax| p * b.shape()[ax]); + + let a_dyn = a.view().into_dimensionality::().unwrap(); + let b_dyn = b.view().into_dimensionality::().unwrap(); + + let a_perm = a_dyn.permuted_axes(IxDyn(&newaxes_a)); + let a_std = a_perm.as_standard_layout(); + + let b_perm = b_dyn.permuted_axes(IxDyn(&newaxes_b)); + let b_std = b_perm.as_standard_layout(); + + let a2 = a_std + .into_shape_with_order(Ix2(m, k)) + .expect("reshaping a to 2D"); + let b2 = b_std + .into_shape_with_order(Ix2(k, n)) + .expect("reshaping b to 2D"); + + let c2 = a2.dot(&b2); + + let mut out_shape: Vec = notin_a.iter().map(|&ax| a.shape()[ax]).collect(); + out_shape.extend(notin_b.iter().map(|&ax| b.shape()[ax])); + + c2.into_shape_with_order(IxDyn(&out_shape)).unwrap() +} + +// ArrayBase × ArrayBase +impl Tensordot> for ArrayBase +where + A: LinalgScalar, + S: Data, + S2: Data, + D1: Dimension, + D2: Dimension, +{ + type Output = ArrayD; + + #[track_caller] + fn tensordot(&self, rhs: &ArrayBase, axes: AxisSpec) -> Self::Output + { + tensordot_impl::(self, rhs, axes) + } +} + +// ArrayBase × ArrayRef (rhs is ArrayRef) → pass a view to backend +impl Tensordot> for ArrayBase +where + A: LinalgScalar, + S: Data, + D1: Dimension, + D2: Dimension, +{ + type Output = ArrayD; + + #[track_caller] + fn tensordot(&self, rhs: &ArrayRef, axes: AxisSpec) -> Self::Output + { + let rhs_view: ArrayBase, D2> = rhs.view(); + tensordot_impl::, D1, D2>(self, &rhs_view, axes) + } +} + +// ArrayRef × ArrayBase (self is ArrayRef) → pass a view to backend +impl Tensordot> for ArrayRef +where + A: LinalgScalar, + S: Data, + D1: Dimension, + D2: Dimension, +{ + type Output = ArrayD; + + #[track_caller] + fn tensordot(&self, rhs: &ArrayBase, axes: AxisSpec) -> Self::Output + { + let self_view: ArrayBase, D1> = self.view(); + tensordot_impl::, S, D1, D2>(&self_view, rhs, axes) + } +} + +#[cfg(test)] +mod tensordot_tests +{ + use super::*; + use crate::{ArrayD, IxDyn}; + + #[test] + fn basic_pair_axes() + { + // a.shape = [3, 4, 5], b.shape = [4, 3, 2] + let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::>()).unwrap(); + let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::>()).unwrap(); + + let c: ArrayD = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1])); + + // Expected shape: [5, 2] + assert_eq!( + c.shape(), + &[5, 2], + "unexpected output shape: got {:?}, expected [5, 2]", + c.shape() + ); + + // Spot check one known value + assert_eq!( + c[[0, 0]], 4400, + "unexpected value at [0,0]: got {}, expected 4400", + c[[0, 0]] + ); + + // Check consistency of the entire first row (informative failure message) + let first_row = c.slice(s![0, ..]).to_vec(); + assert_eq!( + first_row.len(), + 2, + "first row length mismatch: got {}, expected 2", + first_row.len() + ); + } + + #[test] + fn integer_axes() + { + // a.shape = [2, 2, 2], b.shape = [2, 2] + let a = ArrayD::from_shape_vec(IxDyn(&[2, 2, 2]), (1..=8).collect::>()).unwrap(); + let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![10, 20, 30, 40]).unwrap(); + + // Contract over 2 axes + let c: ArrayD = tensordot(&a, &b, AxisSpec::Num(2)); + + assert_eq!( + c.shape(), + &[2], + "unexpected output shape: got {:?}, expected [2]", + c.shape() + ); + + // Extract result as slice for easy comparison + let got = c.as_slice().expect("array not contiguous"); + let expected = [300, 700]; // verified numeric result + + assert_eq!( + got, + expected, + "tensor contraction result mismatch:\n got {:?}\n expected {:?}", + got, + expected + ); + } +} diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index dc6964f9b..e31294b42 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -12,5 +12,6 @@ pub use self::impl_linalg::general_mat_mul; pub use self::impl_linalg::general_mat_vec_mul; pub use self::impl_linalg::kron; pub use self::impl_linalg::Dot; +pub use self::impl_linalg::{tensordot, AxisSpec, Tensordot}; mod impl_linalg; From 634c3d44b3c102f702116510a598abdbd59ba862 Mon Sep 17 00:00:00 2001 From: Jack Geraghty Date: Sat, 8 Nov 2025 16:41:04 +0000 Subject: [PATCH 2/3] Fix doc links --- src/linalg/impl_linalg.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 000fcee39..ce3cd0636 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1173,14 +1173,14 @@ where A: LinalgScalar /// - Axis indices can be negative, in which case they are interpreted /// relative to the end of the tensor (e.g., `-1` refers to the last axis). /// - The number and dimensionality of contracted axes determine the rank of -/// the result of [`tensordot`]. +/// the result of [tensordot]. /// - `AxisSpec` exists to disambiguate and formalise axis specifications, -/// avoiding confusion with [`ndarray::Axis`] and ['ndarray::iter::Axes']. +/// avoiding confusion with [crate::Axis] and [crate::iter::Axes]. /// /// # See also /// -/// [`tensordot`] — Performs the generalized tensor contraction described by this specification. -/// [`ndarray::Axis`] — Represents a single axis index within an array. +/// [tensordot] — Performs the generalized tensor contraction described by this specification. +/// [`Axis`] — Represents a single axis index within an array. #[derive(Clone, Debug)] pub enum AxisSpec { From eb7524fda5146c46f6a95defc91660a5efcc4ccc Mon Sep 17 00:00:00 2001 From: Jack Geraghty <121042936+jmg049@users.noreply.github.com> Date: Thu, 13 Nov 2025 00:21:38 +0000 Subject: [PATCH 3/3] Tensordot implementation refinement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates ``tensordot_impl`` by removing one vector allocation, eliminating repeated axis-membership scans, and reducing shape-index lookups. Also switches axes_a/axes_b to borrowed inputs for potential reuse upstream. Replace ``notin_a`` + ``clone`` with direct construction of ``out_shape``, removing 1 allocation and 1 clone. Allocation count is now: - ``is_contracted_a`` → ``O(ndim(a))`` - is_contracted_b → ``O(ndim(b))`` - newaxes_a → ``O(ndim(a))`` - newaxes_b → ``O(ndim(b))`` - out_shape → ``O(ndim(a) + ndim(b) - contracted)`` All sizes are exactly determined by the axis mask and do not depend on runtime data beyond shape rank. Precompute boolean membership arrays for contracted axes, replacing multiple ``iter().any()`` scans with O(1) lookups. Cache ``a.shape()`` and ``b.shape()`` slices to avoid repeated indexing. Update signature to accept borrowed axis lists, allowing the caller to reuse them without moving. --- src/linalg/impl_linalg.rs | 169 ++++++++++++++++++++++++++------------ 1 file changed, 118 insertions(+), 51 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index ce3cd0636..5fc0a03eb 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1252,14 +1252,14 @@ pub trait Tensordot /// - The computed product of reshaped dimensions does not equal the /// array’s total element count (which would indicate internal logic error). #[track_caller] - fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output; + fn tensordot(&self, rhs: &Rhs, axes: &AxisSpec) -> Self::Output; } /// Perform a tensor contraction along specified axes. /// /// See [`Tensordot::tensordot`] for more details. #[track_caller] -pub fn tensordot(a: &ArrayBase, b: &ArrayBase, axes: AxisSpec) -> ArrayD +pub fn tensordot(a: &ArrayBase, b: &ArrayBase, axes: &AxisSpec) -> ArrayD where T: LinalgScalar, Sa: Data, @@ -1272,7 +1272,11 @@ where /// Performs the full contraction given resolved axis specification. #[track_caller] -fn tensordot_impl(a: &ArrayBase, b: &ArrayBase, axes: AxisSpec) -> ArrayD +fn tensordot_impl( + a: &ArrayBase, + b: &ArrayBase, + axes: &AxisSpec, +) -> ArrayD where T: LinalgScalar, Sa: Data, @@ -1283,10 +1287,14 @@ where let nda = a.ndim() as isize; let ndb = b.ndim() as isize; - // Resolve axes - let (mut axes_a, mut axes_b): (Vec, Vec) = match axes { - AxisSpec::Num(n) => { - let n = n as isize; + // Precompute shapes for reuse + let ashape = a.shape(); + let bshape = b.shape(); + + // Resolve and normalise contracted axes (into owned Vecs, no cloning of input) + let (axes_a, axes_b): (Vec, Vec) = match axes { + AxisSpec::Num(n_raw) => { + let n = *n_raw as isize; assert!( n <= nda && n <= ndb, "tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}", @@ -1294,7 +1302,17 @@ where nda, ndb ); - ((nda - n)..nda).zip(0..n).map(|(ia, ib)| (ia, ib)).unzip() + + let mut axes_a = Vec::with_capacity(n as usize); + let mut axes_b = Vec::with_capacity(n as usize); + + // last n axes of a, first n axes of b + for i in 0..n { + axes_a.push(nda - n + i); + axes_b.push(i); + } + + (axes_a, axes_b) } AxisSpec::Pair(aa, bb) => { assert_eq!( @@ -1304,23 +1322,27 @@ where aa.len(), bb.len() ); - (aa, bb) - } - }; - // Normalise negative indices - for ax in &mut axes_a { - if *ax < 0 { - *ax += nda; - } - } - for ax in &mut axes_b { - if *ax < 0 { - *ax += ndb; + let mut axes_a = Vec::with_capacity(aa.len()); + let mut axes_b = Vec::with_capacity(bb.len()); + + // Normalise negatives for a + for &ax in aa { + let ax_norm = if ax < 0 { ax + nda } else { ax }; + axes_a.push(ax_norm); + } + + // Normalise negatives for b + for &ax in bb { + let ax_norm = if ax < 0 { ax + ndb } else { ax }; + axes_b.push(ax_norm); + } + + (axes_a, axes_b) } - } + }; - // Validate + // Validate bounds for &ax in &axes_a { assert!( (0..nda).contains(&ax), @@ -1338,10 +1360,10 @@ where ); } - // Shape checks + // Shape checks on contracted axes for (ia, ib) in axes_a.iter().zip(&axes_b) { - let da = a.shape()[*ia as usize]; - let db = b.shape()[*ib as usize]; + let da = ashape[*ia as usize]; + let db = bshape[*ib as usize]; assert_eq!( da, db, "tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}", @@ -1349,25 +1371,72 @@ where ); } - // Determine non-contracted axes - let notin_a: Vec = (0..nda as usize) - .filter(|k| !axes_a.iter().any(|&ax| ax as usize == *k)) - .collect(); - let notin_b: Vec = (0..ndb as usize) - .filter(|k| !axes_b.iter().any(|&ax| ax as usize == *k)) - .collect(); - - // Reorder axes - let mut newaxes_a = notin_a.clone(); - newaxes_a.extend(axes_a.iter().map(|&x| x as usize)); - let mut newaxes_b = axes_b.iter().map(|&x| x as usize).collect::>(); - newaxes_b.extend(notin_b.iter().copied()); - - // Matrix shapes - let m = notin_a.iter().fold(1, |p, &ax| p * a.shape()[ax]); - let k = axes_a.iter().fold(1, |p, &ax| p * a.shape()[ax as usize]); - let n = notin_b.iter().fold(1, |p, &ax| p * b.shape()[ax]); + // Membership maps for contracted axes (O(ndim) setup, O(1) lookup) + let mut is_contracted_a = vec![false; nda as usize]; + let mut is_contracted_b = vec![false; ndb as usize]; + for &ax in &axes_a { + is_contracted_a[ax as usize] = true; + } + for &ax in &axes_b { + is_contracted_b[ax as usize] = true; + } + + let contracted_a = axes_a.len(); + let contracted_b = axes_b.len(); + debug_assert_eq!(contracted_a, contracted_b); + let free_a = nda as usize - contracted_a; + let free_b = ndb as usize - contracted_b; + + // Permutation axes for a: [non-contracted..., contracted...] + let mut newaxes_a = Vec::with_capacity(nda as usize); + for i in 0..nda as usize { + if !is_contracted_a[i] { + newaxes_a.push(i); + } + } + for &ax in &axes_a { + newaxes_a.push(ax as usize); + } + + // non-contracted axes for b (indices) + let mut notin_b = Vec::with_capacity(free_b); + for i in 0..ndb as usize { + if !is_contracted_b[i] { + notin_b.push(i); + } + } + + // Permutation axes for b: [contracted..., non-contracted...] + let mut newaxes_b = Vec::with_capacity(ndb as usize); + for &ax in &axes_b { + newaxes_b.push(ax as usize); + } + newaxes_b.extend(¬in_b); + + // Output shape: a(non-contracted) ⧺ b(non-contracted) + let mut out_shape = Vec::with_capacity(free_a + free_b); + for i in 0..nda as usize { + if !is_contracted_a[i] { + out_shape.push(ashape[i]); + } + } + for &ax in ¬in_b { + out_shape.push(bshape[ax]); + } + + // Matrix dims + let m = newaxes_a[..free_a] + .iter() + .fold(1, |p, &ax| p * ashape[ax]); + let k = axes_a + .iter() + .fold(1, |p, &ax| p * ashape[ax as usize]); + let n = notin_b + .iter() + .fold(1, |p, &ax| p * bshape[ax]); + + // Permute + standard layout (keep temporaries named to satisfy lifetimes) let a_dyn = a.view().into_dimensionality::().unwrap(); let b_dyn = b.view().into_dimensionality::().unwrap(); @@ -1377,6 +1446,7 @@ where let b_perm = b_dyn.permuted_axes(IxDyn(&newaxes_b)); let b_std = b_perm.as_standard_layout(); + // Reshape to 2D, multiply, and reshape back let a2 = a_std .into_shape_with_order(Ix2(m, k)) .expect("reshaping a to 2D"); @@ -1386,9 +1456,6 @@ where let c2 = a2.dot(&b2); - let mut out_shape: Vec = notin_a.iter().map(|&ax| a.shape()[ax]).collect(); - out_shape.extend(notin_b.iter().map(|&ax| b.shape()[ax])); - c2.into_shape_with_order(IxDyn(&out_shape)).unwrap() } @@ -1404,7 +1471,7 @@ where type Output = ArrayD; #[track_caller] - fn tensordot(&self, rhs: &ArrayBase, axes: AxisSpec) -> Self::Output + fn tensordot(&self, rhs: &ArrayBase, axes: &AxisSpec) -> Self::Output { tensordot_impl::(self, rhs, axes) } @@ -1421,7 +1488,7 @@ where type Output = ArrayD; #[track_caller] - fn tensordot(&self, rhs: &ArrayRef, axes: AxisSpec) -> Self::Output + fn tensordot(&self, rhs: &ArrayRef, axes: &AxisSpec) -> Self::Output { let rhs_view: ArrayBase, D2> = rhs.view(); tensordot_impl::, D1, D2>(self, &rhs_view, axes) @@ -1439,7 +1506,7 @@ where type Output = ArrayD; #[track_caller] - fn tensordot(&self, rhs: &ArrayBase, axes: AxisSpec) -> Self::Output + fn tensordot(&self, rhs: &ArrayBase, axes: &AxisSpec) -> Self::Output { let self_view: ArrayBase, D1> = self.view(); tensordot_impl::, S, D1, D2>(&self_view, rhs, axes) @@ -1459,7 +1526,7 @@ mod tensordot_tests let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::>()).unwrap(); let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::>()).unwrap(); - let c: ArrayD = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1])); + let c: ArrayD = tensordot(&a, &b, &AxisSpec::Pair(vec![1, 0], vec![0, 1])); // Expected shape: [5, 2] assert_eq!( @@ -1494,7 +1561,7 @@ mod tensordot_tests let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![10, 20, 30, 40]).unwrap(); // Contract over 2 axes - let c: ArrayD = tensordot(&a, &b, AxisSpec::Num(2)); + let c: ArrayD = tensordot(&a, &b, &AxisSpec::Num(2)); assert_eq!( c.shape(),