Skip to content

Commit 011c346

Browse files
committed
Simplify Dirichlet a bit
1 parent d936404 commit 011c346

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

cp-algo/number_theory/dirichlet.hpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ namespace cp_algo::math {
1515
auto operator <=>(const interval&) const = default;
1616
};
1717

18-
// callback(k, prefix) such that:
19-
// (F * G)[k] = prefix + (F[k] - F[k-1]) * G[1] + (G[k] - G[k-1]) * F[1]
20-
// Uses H as a buffer for (F * G)[k], then overrides with callback results
18+
// callback(k) when:
19+
// (F * G)[k] = H[k] + (F[k] - F[k-1]) * G[1] + (G[k] - G[k-1]) * F[1]
20+
// Return the value to be saved in H[k]
2121
enum exec_mode { standard, reverse };
2222
template<exec_mode mode = standard>
2323
void exec_on_blocks(int64_t n, auto &H, auto const& F, auto const& G, auto &&callback) {
@@ -30,26 +30,23 @@ namespace cp_algo::math {
3030
auto call = [&](interval x, interval y, interval z) {
3131
auto Fx = F[x.hi] - F[x.lo - 1];
3232
auto Fy = F[y.hi] - F[y.lo - 1];
33-
decltype(Fx) Gx, Gy, t;
33+
decltype(Fx) Gx, Gy;
3434
if constexpr (mode == standard) {
3535
Gy = G[y.hi] - G[y.lo - 1];
3636
Gx = G[x.hi] - G[x.lo - 1];
3737
} else {
3838
Gy = G[y.lo - 1] - G[y.hi];
3939
Gx = G[x.lo - 1] - G[x.hi];
4040
}
41-
if(x == y) [[unlikely]] {
42-
t = Fx * Gy;
43-
} else {
44-
t = Fx * Gy + Fy * Gx;
41+
auto t = Fx * Gy;
42+
if(x != y) [[likely]] {
43+
t += Fy * Gx;
4544
}
4645
H[z.lo] += t;
47-
if (z.hi < num_floors) {
46+
if (z.hi < num_floors) [[likely]] {
4847
H[z.hi + 1] -= t;
4948
}
5049
};
51-
52-
auto prefix = F[1] * G[1];
5350
for (int k = 2; k <= num_floors; ++k) {
5451
if(k > rt_n) {
5552
int z = num_floors - k + 1;
@@ -61,8 +58,7 @@ namespace cp_algo::math {
6158
}
6259
}
6360

64-
H[k] = callback(k, prefix += H[k]);
65-
prefix += (F[k] - F[k-1]) * G[1] + (G[k] - G[k-1]) * F[1];
61+
H[k] = callback(k);
6662

6763
if(k <= rt_n) {
6864
int x = k;
@@ -82,9 +78,10 @@ namespace cp_algo::math {
8278
auto m = size(F);
8379
std::decay_t<decltype(F)> H(m);
8480
H[1] = F[1] * G[1];
85-
exec_on_blocks(n, H, F, G, [&](auto k, auto prefix) {
86-
return prefix + (F[k] - F[k-1]) * G[1] + (G[k] - G[k-1]) * F[1];
81+
exec_on_blocks(n, H, F, G, [&](auto k) {
82+
return H[k] + (F[k] - F[k-1]) * G[1] + (G[k] - G[k-1]) * F[1];
8783
});
84+
partial_sum(begin(H), end(H), begin(H));
8885
return H;
8986
}
9087

@@ -94,7 +91,7 @@ namespace cp_algo::math {
9491
H[0] -= H[0];
9592
adjacent_difference(begin(H), end(H), begin(H));
9693
H[1] *= Gi;
97-
exec_on_blocks<reverse>(n, H, H, G, [&](auto k, auto) {
94+
exec_on_blocks<reverse>(n, H, H, G, [&](auto k) {
9895
return (Gi * (H[k] - (G[k] - G[k-1]) * H[1])) + H[k-1];
9996
});
10097
}

0 commit comments

Comments
 (0)