Skip to content

Commit 7e8c1b5

Browse files
committed
only use slices
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent a07666b commit 7e8c1b5

File tree

1 file changed

+22
-72
lines changed

1 file changed

+22
-72
lines changed

vortex-compute/src/filter/buffer_mut.rs

Lines changed: 22 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,12 @@
44
use core::ptr;
55

66
use vortex_buffer::BufferMut;
7-
use vortex_mask::{Mask, MaskIter};
7+
use vortex_mask::Mask;
88

99
use crate::filter::Filter;
1010

1111
// TODO(connor): Implement `Filter` for more combinations of `Filter`
1212

13-
// TODO(connor): Figure out if this threshold makes sense for in-place filter.
14-
// This is modeled after the constant with the equivalent name in arrow-rs.
15-
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
16-
1713
impl<T: Copy> Filter for &mut BufferMut<T> {
1814
type Output = ();
1915

@@ -27,54 +23,32 @@ impl<T: Copy> Filter for &mut BufferMut<T> {
2723
match selection_mask {
2824
Mask::AllTrue(_) => {}
2925
Mask::AllFalse(_) => self.clear(),
30-
// SAFETY: We checked above that the selection mask has the same length as the buffer.
31-
Mask::Values(values) => unsafe {
32-
match values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
33-
MaskIter::Indices(indices) => filter_indices_in_place(self, indices),
34-
MaskIter::Slices(slices) => filter_slices_in_place(self, slices),
35-
}
36-
},
26+
Mask::Values(values) => {
27+
// We choose to _always_ use slices here because iterating over indices never takes fewer
28+
// loop iterations than iterating over slices, and the per-element overhead compared to a
29+
// batched `ptr::copy(len)` is not worth it.
30+
let slices = values.slices();
31+
32+
// SAFETY: We checked above that the selection mask has the same length as the
33+
// buffer.
34+
let new_len = unsafe { filter_slices_in_place(self, slices) };
35+
36+
// Truncate the buffer to the new length.
37+
self.truncate(new_len);
38+
}
3739
}
3840
}
3941
}
4042

41-
/// Filters a buffer in-place using indices to determine which values to keep.
42-
///
43-
/// # Safety
44-
///
45-
/// The indices must be in the range of the `buffer`.
46-
unsafe fn filter_indices_in_place<T: Copy>(buffer: &mut BufferMut<T>, indices: &[usize]) {
47-
let slice = buffer.as_mut_slice();
48-
let mut write_idx = 0;
49-
50-
// For each index in the selection, copy the element to the current write position.
51-
for &read_idx in indices {
52-
// Note that we could add an if statement here that checks `if read_idx != write_idx` and
53-
// use `ptr::copy_nonoverlapping`, but it's probably better to just avoid the branch
54-
// misprediction.
55-
56-
// SAFETY: Both indices are within bounds since indices come from a valid mask.
57-
unsafe {
58-
ptr::copy(
59-
slice.as_ptr().add(read_idx),
60-
slice.as_mut_ptr().add(write_idx),
61-
1,
62-
)
63-
};
64-
write_idx += 1;
65-
}
66-
67-
// Truncate the buffer to the new length.
68-
buffer.truncate(write_idx);
69-
}
70-
7143
/// Filters a buffer in-place using slice ranges to determine which values to keep.
7244
///
45+
/// Returns the new length of the buffer.
46+
///
7347
/// # Safety
7448
///
7549
/// The slice ranges must be in the range of the `buffer`.
76-
unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut BufferMut<T>, slices: &[(usize, usize)]) {
77-
let slice = buffer.as_mut_slice();
50+
#[must_use = "The caller should set the new length of the buffer"]
51+
unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut [T], slices: &[(usize, usize)]) -> usize {
7852
let mut write_pos = 0;
7953

8054
// For each range in the selection, copy all of the elements to the current write position.
@@ -84,21 +58,19 @@ unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut BufferMut<T>, slices: &[(
8458

8559
let len = end - start;
8660

87-
// SAFETY: The ranges are within bounds since they come from a valid mask for the
88-
// buffer.
61+
// SAFETY: The caller guarantees, per this function's safety contract, that all ranges are within bounds.
8962
unsafe {
9063
ptr::copy(
91-
slice.as_ptr().add(start),
92-
slice.as_mut_ptr().add(write_pos),
64+
buffer.as_ptr().add(start),
65+
buffer.as_mut_ptr().add(write_pos),
9366
len,
9467
)
9568
};
9669

9770
write_pos += len;
9871
}
9972

100-
// Truncate the buffer to the new length.
101-
buffer.truncate(write_pos);
73+
write_pos
10274
}
10375

10476
#[cfg(test)]
@@ -213,26 +185,4 @@ mod tests {
213185

214186
buf.filter(&mask);
215187
}
216-
217-
#[test]
218-
fn test_filter_indices_direct() {
219-
let mut buf = buffer_mut![100u32, 200, 300, 400, 500];
220-
unsafe { filter_indices_in_place(&mut buf, &[1, 3, 4]) };
221-
assert_eq!(buf.as_slice(), &[200, 400, 500]);
222-
}
223-
224-
#[test]
225-
fn test_filter_slices_direct() {
226-
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7];
227-
unsafe { filter_slices_in_place(&mut buf, &[(1, 3), (5, 7)]) };
228-
assert_eq!(buf.as_slice(), &[2, 3, 6, 7]);
229-
}
230-
231-
#[test]
232-
fn test_filter_overlapping_slices() {
233-
// Test that overlapping regions are handled correctly.
234-
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7, 8];
235-
unsafe { filter_slices_in_place(&mut buf, &[(2, 6)]) };
236-
assert_eq!(buf.as_slice(), &[3, 4, 5, 6]);
237-
}
238188
}

0 commit comments

Comments
 (0)