Skip to content

Commit 5ac8d69

Browse files
committed
Make the entire crate compatible with the array reference type
1 parent 7997e20 commit 5ac8d69

File tree

16 files changed

+190
-234
lines changed

16 files changed

+190
-234
lines changed

Cargo.lock

Lines changed: 70 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"]
2020
categories = ["data-structures", "science"]
2121

2222
[dependencies]
23-
ndarray = "0.16.0"
23+
ndarray = "0.17.1"
2424
noisy_float = "0.2.0"
2525
num-integer = "0.1"
2626
num-traits = "0.2"
@@ -29,10 +29,10 @@ itertools = { version = "0.13", default-features = false }
2929
indexmap = "2.4"
3030

3131
[dev-dependencies]
32-
ndarray = { version = "0.16.1", features = ["approx"] }
32+
ndarray = { version = "0.17.1", features = ["approx"] }
3333
criterion = "0.5.1"
3434
quickcheck = { version = "0.9.2", default-features = false }
35-
ndarray-rand = "0.15.0"
35+
ndarray-rand = "0.16.0"
3636
approx = "0.5"
3737
quickcheck_macros = "1.0.0"
3838
num-bigint = "0.4.0"

benches/deviation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn sq_l2_dist(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let data2 = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let data2 = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717

1818
b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap()))
1919
});

benches/summary_statistics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn weighted_std(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let mut weights = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let mut weights = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717
weights /= weights.sum();
1818
b.iter_batched(
1919
|| data.clone(),

src/correlation.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use crate::errors::EmptyInput;
22
use ndarray::prelude::*;
3-
use ndarray::Data;
43
use num_traits::{Float, FromPrimitive};
54

6-
/// Extension trait for `ArrayBase` providing functions
5+
/// Extension trait for `ndarray` providing functions
76
/// to compute different correlation measures.
8-
pub trait CorrelationExt<A, S>
9-
where
10-
S: Data<Elem = A>,
11-
{
7+
pub trait CorrelationExt<A> {
128
/// Return the covariance matrix `C` for a 2-dimensional
139
/// array of observations `M`.
1410
///
@@ -125,10 +121,7 @@ where
125121
private_decl! {}
126122
}
127123

128-
impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
129-
where
130-
S: Data<Elem = A>,
131-
{
124+
impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
132125
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
133126
where
134127
A: Float + FromPrimitive,
@@ -147,7 +140,7 @@ where
147140
let mean = self.mean_axis(observation_axis);
148141
match mean {
149142
Some(mean) => {
150-
let denoised = self - &mean.insert_axis(observation_axis);
143+
let denoised = self - mean.insert_axis(observation_axis);
151144
let covariance = denoised.dot(&denoised.t());
152145
Ok(covariance.mapv_into(|x| x / dof))
153146
}
@@ -208,7 +201,7 @@ mod cov_tests {
208201
let n_observations = 4;
209202
let a = Array::random(
210203
(n_random_variables, n_observations),
211-
Uniform::new(-bound.abs(), bound.abs()),
204+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
212205
);
213206
let covariance = a.cov(1.).unwrap();
214207
abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
@@ -219,7 +212,10 @@ mod cov_tests {
219212
fn test_invalid_ddof() {
220213
let n_random_variables = 3;
221214
let n_observations = 4;
222-
let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
215+
let a = Array::random(
216+
(n_random_variables, n_observations),
217+
Uniform::new(0., 10.).unwrap(),
218+
);
223219
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
224220
let _ = a.cov(invalid_ddof);
225221
}
@@ -299,7 +295,7 @@ mod pearson_correlation_tests {
299295
let n_observations = 4;
300296
let a = Array::random(
301297
(n_random_variables, n_observations),
302-
Uniform::new(-bound.abs(), bound.abs()),
298+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
303299
);
304300
let pearson_correlation = a.pearson_correlation().unwrap();
305301
abs_diff_eq!(

0 commit comments

Comments
 (0)