@@ -162,21 +162,24 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
162162 ( r0, m0)
163163}
164164
165- /// Returns $\sum_{i = 0}^{n - 1} \lfloor \frac{a \times i + b}{m} \rfloor$.
165+ /// Returns
166+ ///
167+ /// $$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor.$$
168+ ///
169+ /// It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed.
166170///
167171/// # Constraints
168172///
169- /// - $0 \leq n \leq 10^9$
170- /// - $1 \leq m \leq 10^9$
171- /// - $0 \leq a, b \leq m$
173+ /// - $0 \leq n \lt 2^{32}$
174+ /// - $1 \leq m \lt 2^{32}$
172175///
173176/// # Panics
174177///
175178/// Panics if the above constraints are not satisfied and overflow or division by zero occurred.
176179///
177180/// # Complexity
178181///
179- /// - $O(\log(n + m + a + b) )$
182+ /// - $O(\log{(m+a)} )$
180183///
181184/// # Example
182185///
@@ -185,25 +188,25 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
185188///
186189/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
187190/// ```
188- pub fn floor_sum ( n : i64 , m : i64 , mut a : i64 , mut b : i64 ) -> i64 {
189- let mut ans = 0 ;
190- if a >= m {
191- ans += ( n - 1 ) * n * ( a / m) / 2 ;
192- a %= m;
193- }
194- if b >= m {
195- ans += n * ( b / m) ;
196- b %= m;
191+ #[ allow( clippy:: many_single_char_names) ]
192+ pub fn floor_sum ( n : i64 , m : i64 , a : i64 , b : i64 ) -> i64 {
193+ use std:: num:: Wrapping as W ;
194+ assert ! ( ( 0 ..1i64 << 32 ) . contains( & n) ) ;
195+ assert ! ( ( 1 ..1i64 << 32 ) . contains( & m) ) ;
196+ let mut ans = W ( 0_u64 ) ;
197+ let ( wn, wm, mut wa, mut wb) = ( W ( n as u64 ) , W ( m as u64 ) , W ( a as u64 ) , W ( b as u64 ) ) ;
198+ if a < 0 {
199+ let a2 = W ( internal_math:: safe_mod ( a, m) as u64 ) ;
200+ ans -= wn * ( wn - W ( 1 ) ) / W ( 2 ) * ( ( a2 - wa) / wm) ;
201+ wa = a2;
197202 }
198-
199- let y_max = ( a * n + b) / m;
200- let x_max = y_max * m - b;
201- if y_max == 0 {
202- return ans;
203+ if b < 0 {
204+ let b2 = W ( internal_math:: safe_mod ( b, m) as u64 ) ;
205+ ans -= wn * ( ( b2 - wb) / wm) ;
206+ wb = b2;
203207 }
204- ans += ( n - ( x_max + a - 1 ) / a) * y_max;
205- ans += floor_sum ( y_max, a, m, ( a - x_max % a) % a) ;
206- ans
208+ let ret = ans + internal_math:: floor_sum_unsigned ( wn, wm, wa, wb) ;
209+ ret. 0 as i64
207210}
208211
209212#[ cfg( test) ]
@@ -306,5 +309,24 @@ mod tests {
306309 499_999_999_500_000_000
307310 ) ;
308311 assert_eq ! ( floor_sum( 332955 , 5590132 , 2231 , 999423 ) , 22014575 ) ;
312+ for n in 0 ..20 {
313+ for m in 1 ..20 {
314+ for a in -20 ..20 {
315+ for b in -20 ..20 {
316+ assert_eq ! ( floor_sum( n, m, a, b) , floor_sum_naive( n, m, a, b) ) ;
317+ }
318+ }
319+ }
320+ }
321+ }
322+
323+ #[ allow( clippy:: many_single_char_names) ]
324+ fn floor_sum_naive ( n : i64 , m : i64 , a : i64 , b : i64 ) -> i64 {
325+ let mut ans = 0 ;
326+ for i in 0 ..n {
327+ let z = a * i + b;
328+ ans += ( z - internal_math:: safe_mod ( z, m) ) / m;
329+ }
330+ ans
309331 }
310332}
0 commit comments