@@ -29,20 +29,37 @@ namespace cp_algo::math::fft {
2929
3030 dft (auto const & a, size_t n): A(n), B(n) {
3131 init ();
32- base cur = factor;
33- base step = bpow (factor, n);
34- for (size_t i = 0 ; i < std::min (n, size (a)); i++) {
32+ base b2x32 = bpow (base (2 ), 32 );
33+ u64x4 cur = {
34+ (bpow (factor, 1 ) * b2x32).getr (),
35+ (bpow (factor, 2 ) * b2x32).getr (),
36+ (bpow (factor, 3 ) * b2x32).getr (),
37+ (bpow (factor, 4 ) * b2x32).getr ()
38+ };
39+ u64x4 step4 = u64x4{} + (bpow (factor, 4 ) * b2x32).getr ();
40+ u64x4 stepn = u64x4{} + (bpow (factor, n) * b2x32).getr ();
41+ for (size_t i = 0 ; i < std::min (n, size (a)); i += flen) {
3542 auto splt = [&](size_t i, auto mul) {
36- auto ai = i < size (a) ? (a[i] * mul).rem () : 0 ;
37- auto quo = ai / split;
38- auto rem = ai % split;
39- return std::pair{(ftype)rem, (ftype)quo};
43+ if (i >= size (a)) {
44+ return std::pair{vftype (), vftype ()};
45+ }
46+ u64x4 au = {
47+ i < size (a) ? a[i].getr () : 0 ,
48+ i + 1 < size (a) ? a[i + 1 ].getr () : 0 ,
49+ i + 2 < size (a) ? a[i + 2 ].getr () : 0 ,
50+ i + 3 < size (a) ? a[i + 3 ].getr () : 0
51+ };
52+ au = montgomery_mul (au, mul, mod, imod);
53+ au = au >= base::mod () ? au - base::mod () : au;
54+ auto ai = i64x4 (au);
55+ ai = ai >= base::mod () / 2 ? ai - base::mod () : ai;
56+ return std::pair{to_double (ai % split), to_double (ai / split)};
4057 };
4158 auto [rai, qai] = splt (i, cur);
42- auto [rani, qani] = splt (n + i, cur * step );
43- A.set (i, point (rai, rani) );
44- B.set (i, point (qai, qani) );
45- cur *= factor ;
59+ auto [rani, qani] = splt (n + i, montgomery_mul ( cur, stepn, mod, imod) );
60+ A.at (i) = vpoint (rai, rani);
61+ B.at (i) = vpoint (qai, qani);
62+ cur = montgomery_mul (cur, step4, mod, imod) ;
4663 }
4764 checkpoint (" dft init" );
4865 if (n) {
0 commit comments