44use core:: ptr;
55
66use vortex_buffer:: BufferMut ;
7- use vortex_mask:: { Mask , MaskIter } ;
7+ use vortex_mask:: Mask ;
88
99use crate :: filter:: Filter ;
1010
1111// TODO(connor): Implement `Filter` for more combinations of `Filter`
1212
13- // TODO(connor): Figure out if this threshold makes sense for in-place filter.
14- // This is modeled after the constant with the equivalent name in arrow-rs.
15- const FILTER_SLICES_SELECTIVITY_THRESHOLD : f64 = 0.8 ;
16-
1713impl < T : Copy > Filter for & mut BufferMut < T > {
1814 type Output = ( ) ;
1915
@@ -27,54 +23,32 @@ impl<T: Copy> Filter for &mut BufferMut<T> {
2723 match selection_mask {
2824 Mask :: AllTrue ( _) => { }
2925 Mask :: AllFalse ( _) => self . clear ( ) ,
30- // SAFETY: We checked above that the selection mask has the same length as the buffer.
31- Mask :: Values ( values) => unsafe {
32- match values. threshold_iter ( FILTER_SLICES_SELECTIVITY_THRESHOLD ) {
33- MaskIter :: Indices ( indices) => filter_indices_in_place ( self , indices) ,
34- MaskIter :: Slices ( slices) => filter_slices_in_place ( self , slices) ,
35- }
36- } ,
26+ Mask :: Values ( values) => {
27+ // We choose to _always_ use slices here because iterating over indices never takes
28+ // fewer loop iterations than iterating over slices, and the extra overhead compared to
29+ // one batched `ptr::copy(len)` per slice is not worth it.
30+ let slices = values. slices ( ) ;
31+
32+ // SAFETY: We checked above that the selection mask has the same length as the
33+ // buffer.
34+ let new_len = unsafe { filter_slices_in_place ( self , slices) } ;
35+
36+ // Truncate the buffer to the new length.
37+ self . truncate ( new_len) ;
38+ }
3739 }
3840 }
3941}
4042
41- /// Filters a buffer in-place using indices to determine which values to keep.
42- ///
43- /// # Safety
44- ///
45- /// The indices must be in the range of the `buffer`.
46- unsafe fn filter_indices_in_place < T : Copy > ( buffer : & mut BufferMut < T > , indices : & [ usize ] ) {
47- let slice = buffer. as_mut_slice ( ) ;
48- let mut write_idx = 0 ;
49-
50- // For each index in the selection, copy the element to the current write position.
51- for & read_idx in indices {
52- // Note that we could add an if statement here that checks `if read_idx != write_idx` and
53- // use `ptr::copy_nonoverlapping`, but it's probably better to just avoid the branch
54- // misprediction.
55-
56- // SAFETY: Both indices are within bounds since indices come from a valid mask.
57- unsafe {
58- ptr:: copy (
59- slice. as_ptr ( ) . add ( read_idx) ,
60- slice. as_mut_ptr ( ) . add ( write_idx) ,
61- 1 ,
62- )
63- } ;
64- write_idx += 1 ;
65- }
66-
67- // Truncate the buffer to the new length.
68- buffer. truncate ( write_idx) ;
69- }
70-
7143/// Filters a buffer in-place using slice ranges to determine which values to keep.
7244///
45+ /// Returns the new length of the buffer.
46+ ///
7347/// # Safety
7448///
7549/// The slice ranges must be in the range of the `buffer`.
76- unsafe fn filter_slices_in_place < T : Copy > ( buffer : & mut BufferMut < T > , slices : & [ ( usize , usize ) ] ) {
77- let slice = buffer . as_mut_slice ( ) ;
50+ # [ must_use = "The caller should set the new length of the buffer" ]
51+ unsafe fn filter_slices_in_place < T : Copy > ( buffer : & mut [ T ] , slices : & [ ( usize , usize ) ] ) -> usize {
7852 let mut write_pos = 0 ;
7953
8054 // For each range in the selection, copy all of the elements to the current write position.
@@ -84,21 +58,19 @@ unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut BufferMut<T>, slices: &[(
8458
8559 let len = end - start;
8660
87- // SAFETY: The ranges are within bounds since they come from a valid mask for the
88- // buffer.
61+ // SAFETY: The safety contract enforces that all ranges are within bounds.
8962 unsafe {
9063 ptr:: copy (
91- slice . as_ptr ( ) . add ( start) ,
92- slice . as_mut_ptr ( ) . add ( write_pos) ,
64+ buffer . as_ptr ( ) . add ( start) ,
65+ buffer . as_mut_ptr ( ) . add ( write_pos) ,
9366 len,
9467 )
9568 } ;
9669
9770 write_pos += len;
9871 }
9972
100- // Truncate the buffer to the new length.
101- buffer. truncate ( write_pos) ;
73+ write_pos
10274}
10375
10476#[ cfg( test) ]
@@ -213,26 +185,4 @@ mod tests {
213185
214186 buf. filter ( & mask) ;
215187 }
216-
217- #[ test]
218- fn test_filter_indices_direct ( ) {
219- let mut buf = buffer_mut ! [ 100u32 , 200 , 300 , 400 , 500 ] ;
220- unsafe { filter_indices_in_place ( & mut buf, & [ 1 , 3 , 4 ] ) } ;
221- assert_eq ! ( buf. as_slice( ) , & [ 200 , 400 , 500 ] ) ;
222- }
223-
224- #[ test]
225- fn test_filter_slices_direct ( ) {
226- let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 , 6 , 7 ] ;
227- unsafe { filter_slices_in_place ( & mut buf, & [ ( 1 , 3 ) , ( 5 , 7 ) ] ) } ;
228- assert_eq ! ( buf. as_slice( ) , & [ 2 , 3 , 6 , 7 ] ) ;
229- }
230-
231- #[ test]
232- fn test_filter_overlapping_slices ( ) {
233- // Test that overlapping regions are handled correctly.
234- let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ;
235- unsafe { filter_slices_in_place ( & mut buf, & [ ( 2 , 6 ) ] ) } ;
236- assert_eq ! ( buf. as_slice( ) , & [ 3 , 4 , 5 , 6 ] ) ;
237- }
238188}
0 commit comments