Original article: orzsiyuan.com/archives/Al…
The fast Walsh transform (FWT) is used to compute bitwise convolutions of polynomials. Its computation process is similar to that of the FFT.
Overview
First, let's recall ordinary polynomial convolution:
$$C_k = \sum_{i + j = k} A_i B_j$$
In “Algorithm Notes: Fast Fourier Transform (FFT)”, we converted the polynomials $A$ and $B$ into point-value form, multiplied them pointwise, and then converted the result back into coefficient form.
Now consider convolutions of the following form:
$$C_k = \sum_{i \circ j = k} A_i B_j$$
where $\circ$ stands for bitwise OR, bitwise AND, or bitwise XOR.
These cannot be computed with the FFT, so we need a new method: the fast Walsh transform (FWT).
Below, we treat each of these three kinds of convolution in turn. In each case, the “define” part constructs the transform, the “prove” part uses mathematical induction to prove the supporting lemma and theorem, and the “code” part implements the convolution.
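Notation
First, we assume that every polynomial below has length $2^n$ for some nonnegative integer $n$. For convenience, we use the following notation throughout without further explanation:
- $A_0$, $A_1$: the first and second halves of $A$, i.e. the coefficients whose index has bit $n-1$ equal to $0$ and to $1$ respectively (each of length $2^{n-1}$);
- $(A, B)$: the concatenation of $A$ and $B$;
- $A + B$, $A - B$: coefficient-wise sum and difference, $(A \pm B)_i = A_i \pm B_i$;
- $A \times B$: coefficient-wise product, $(A \times B)_i = A_i B_i$;
- $A \mid B$, $A \mathbin{\&} B$, $A \oplus B$: the OR, AND and XOR convolutions, e.g. $(A \mid B)_k = \sum_{i \mid j = k} A_i B_j$;
- $\text{FWT}(A)$, $\text{IFWT}(A)$: the transform and its inverse.

Note in particular: $A \oplus B$, $A \mathbin{\&} B$ and $A \mid B$ as defined here are convolutions, not the bitwise operation applied to corresponding coefficients!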
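OR convolution
Define
$$\text{FWT}(A) = \begin{cases} A & n = 0 \\ (\text{FWT}(A_0),\ \text{FWT}(A_0) + \text{FWT}(A_1)) & n > 0 \end{cases}$$
Unrolling the recursion gives $\text{FWT}(A)_i = \sum_{j \mid i = i} A_j$, i.e. the sum over all $j$ whose set bits are a subset of those of $i$. The inverse transform is
$$\text{IFWT}(A) = \begin{cases} A & n = 0 \\ (\text{IFWT}(A_0),\ \text{IFWT}(A_1) - \text{IFWT}(A_0)) & n > 0 \end{cases}$$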
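Prove
- Lemma: the transform of the sum of two polynomials equals the sum of their transforms, $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n = 0$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 0$, by the induction hypothesis:
$$\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0),\ \text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(B_0),\ \text{FWT}(A_0) + \text{FWT}(A_1) + \text{FWT}(B_0) + \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0),\ \text{FWT}(A_0) + \text{FWT}(A_1)) + (\text{FWT}(B_0),\ \text{FWT}(B_0) + \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}$$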
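- Theorem: the transform of the OR convolution of two polynomials equals the product of their transforms, $\text{FWT}(A \mid B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n = 0$: since $0 \mid 0 = 0$, we have $A \mid B = A \times B$, so $\text{FWT}(A \mid B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n > 0$: splitting indices by bit $n-1$ gives $A \mid B = (A_0 \mid B_0,\ A_0 \mid B_1 + A_1 \mid B_0 + A_1 \mid B_1)$, so by the lemma and the induction hypothesis:
$$\begin{align*}
\text{FWT}(A \mid B) & = (\text{FWT}(A_0 \mid B_0),\ \text{FWT}(A_0 \mid B_0) + \text{FWT}(A_0 \mid B_1) + \text{FWT}(A_1 \mid B_0) + \text{FWT}(A_1 \mid B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0),\ \text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) \\
& \qquad + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0),\ (\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1))) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0),\ \text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1)) \\
& = (\text{FWT}(A_0),\ \text{FWT}(A_0 + A_1)) \times (\text{FWT}(B_0),\ \text{FWT}(B_0 + B_1)) \\
& = \text{FWT}(A) \times \text{FWT}(B)
\end{align*}$$
Code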
void FWTor(std: :vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if(! rev) add(a[i + j + m], a[i + j]);elsesub(a[i + j + m], a[i + j]); }}}Copy the code
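AND convolution
Define
$$\text{FWT}(A) = \begin{cases} A & n = 0 \\ (\text{FWT}(A_0) + \text{FWT}(A_1),\ \text{FWT}(A_1)) & n > 0 \end{cases}$$
Unrolling the recursion gives $\text{FWT}(A)_i = \sum_{j \mathbin{\&} i = i} A_j$, i.e. the sum over all $j$ whose set bits are a superset of those of $i$. The inverse transform is
$$\text{IFWT}(A) = \begin{cases} A & n = 0 \\ (\text{IFWT}(A_0) - \text{IFWT}(A_1),\ \text{IFWT}(A_1)) & n > 0 \end{cases}$$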
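Prove
- Lemma: the transform of the sum of two polynomials equals the sum of their transforms, $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n = 0$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 0$, by the induction hypothesis:
$$\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1),\ \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(B_0) + \text{FWT}(A_1) + \text{FWT}(B_1),\ \text{FWT}(A_1) + \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1),\ \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1),\ \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}$$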
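- Theorem: the transform of the AND convolution of two polynomials equals the product of their transforms, $\text{FWT}(A \mathbin{\&} B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n = 0$: since $0 \mathbin{\&} 0 = 0$, we have $A \mathbin{\&} B = A \times B$, so $\text{FWT}(A \mathbin{\&} B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n > 0$: splitting indices by bit $n-1$ gives $A \mathbin{\&} B = (A_0 \mathbin{\&} B_0 + A_0 \mathbin{\&} B_1 + A_1 \mathbin{\&} B_0,\ A_1 \mathbin{\&} B_1)$, so by the lemma and the induction hypothesis:
$$\begin{align*}
\text{FWT}(A \mathbin{\&} B) & = (\text{FWT}(A_0 \mathbin{\&} B_0) + \text{FWT}(A_0 \mathbin{\&} B_1) + \text{FWT}(A_1 \mathbin{\&} B_0) + \text{FWT}(A_1 \mathbin{\&} B_1),\ \text{FWT}(A_1 \mathbin{\&} B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1), \\
& \qquad \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = ((\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1)),\ \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1),\ \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0 + A_1),\ \text{FWT}(A_1)) \times (\text{FWT}(B_0 + B_1),\ \text{FWT}(B_1)) \\
& = \text{FWT}(A) \times \text{FWT}(B)
\end{align*}$$
Code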
void FWTand(std: :vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if(! rev) add(a[i + j], a[i + j + m]);elsesub(a[i + j], a[i + j + m]); }}}Copy the code
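XOR convolution
Define
$$\text{FWT}(A) = \begin{cases} A & n = 0 \\ (\text{FWT}(A_0) + \text{FWT}(A_1),\ \text{FWT}(A_0) - \text{FWT}(A_1)) & n > 0 \end{cases}$$
Unrolling the recursion gives $\text{FWT}(A)_i = \sum_j (-1)^{\operatorname{popcount}(i \mathbin{\&} j)} A_j$.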
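Prove
- Lemma: the transform of the sum of two polynomials equals the sum of their transforms, $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n = 0$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 0$, by the induction hypothesis:
$$\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1),\ \text{FWT}(A_0 + B_0) - \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1) + \text{FWT}(B_0) + \text{FWT}(B_1),\ \text{FWT}(A_0) - \text{FWT}(A_1) + \text{FWT}(B_0) - \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1),\ \text{FWT}(A_0) - \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1),\ \text{FWT}(B_0) - \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}$$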
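- Theorem: the transform of the XOR convolution of two polynomials equals the product of their transforms, $\text{FWT}(A \oplus B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n = 0$: since $0 \oplus 0 = 0$, we have $A \oplus B = A \times B$, so $\text{FWT}(A \oplus B) = \text{FWT}(A) \times \text{FWT}(B)$.
  - When $n > 0$: splitting indices by bit $n-1$ gives $A \oplus B = (A_0 \oplus B_0 + A_1 \oplus B_1,\ A_0 \oplus B_1 + A_1 \oplus B_0)$. Write $a_k = \text{FWT}(A_k)$ and $b_k = \text{FWT}(B_k)$; by the lemma and the induction hypothesis:
$$\begin{align*}
\text{FWT}(A \oplus B) & = (\text{FWT}(A_0 \oplus B_0 + A_1 \oplus B_1) + \text{FWT}(A_0 \oplus B_1 + A_1 \oplus B_0), \\
& \qquad \text{FWT}(A_0 \oplus B_0 + A_1 \oplus B_1) - \text{FWT}(A_0 \oplus B_1 + A_1 \oplus B_0)) \\
& = (a_0 \times b_0 + a_1 \times b_1 + a_0 \times b_1 + a_1 \times b_0,\ a_0 \times b_0 + a_1 \times b_1 - a_0 \times b_1 - a_1 \times b_0) \\
& = ((a_0 + a_1) \times (b_0 + b_1),\ (a_0 - a_1) \times (b_0 - b_1)) \\
& = (a_0 + a_1,\ a_0 - a_1) \times (b_0 + b_1,\ b_0 - b_1) \\
& = \text{FWT}(A) \times \text{FWT}(B)
\end{align*}$$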
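Inverse transform
Solving the definition for $\text{FWT}(A_0)$ and $\text{FWT}(A_1)$ gives
$$\text{IFWT}(A) = \begin{cases} A & n = 0 \\ \left(\dfrac{\text{IFWT}(A_0) + \text{IFWT}(A_1)}{2},\ \dfrac{\text{IFWT}(A_0) - \text{IFWT}(A_1)}{2}\right) & n > 0 \end{cases}$$
Modulo the prime $P = 998244353$ used in the code, dividing by $2$ means multiplying by $2^{-1} \equiv (P + 1)/2$.
Code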
void FWTxor(std: :vector<int> &a, bool rev) {
int n = a.size(), inv2 = (P + 1) > >1;
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
int x = a[i + j], y = a[i + j + m];
if(! rev) { a[i + j] = (x + y) % P; a[i + j + m] = (x - y + P) % P; }else {
a[i + j] = 1LL * (x + y) * inv2 % P;
a[i + j + m] = 1LL * (x - y + P) * inv2 % P; }}}}Copy the code
The complete code
#include <cstdio>
#include <algorithm>
#include <vector>
// Prime modulus for every coefficient; because P is odd, 2 is invertible mod P.
const int P = 998244353;

// x = (x + y) mod P. Both operands must already lie in [0, P).
void add(int &x, int y) {
    x += y;
    if (x >= P) x -= P;
}

// x = (x - y) mod P. Both operands must already lie in [0, P).
void sub(int &x, int y) {
    x -= y;
    if (x < 0) x += P;
}

// Fast Walsh transforms and the bitwise (OR / AND / XOR) convolutions built on
// them. Every coefficient returned by Or/And/Xor is a residue in [0, P).
struct FWT {
    // Smallest power of two that is >= n (1 when n <= 1).
    int extend(int n) {
        int N = 1;
        while (N < n) N <<= 1;
        return N;
    }

    // In-place OR transform: at every level (A0, A1) -> (A0, A0 + A1).
    // rev = true applies the inverse: (A0, A1) -> (A0, A1 - A0).
    // Requires a.size() to be a power of two and entries in [0, P).
    void FWTor(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1)
            for (int j = 0; j < n; j += l)
                for (int i = 0; i < m; i++) {
                    if (!rev) add(a[i + j + m], a[i + j]);
                    else sub(a[i + j + m], a[i + j]);
                }
    }

    // In-place AND transform: at every level (A0, A1) -> (A0 + A1, A1).
    // rev = true applies the inverse: (A0, A1) -> (A0 - A1, A1).
    // Requires a.size() to be a power of two and entries in [0, P).
    void FWTand(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1)
            for (int j = 0; j < n; j += l)
                for (int i = 0; i < m; i++) {
                    if (!rev) add(a[i + j], a[i + j + m]);
                    else sub(a[i + j], a[i + j + m]);
                }
    }

    // In-place XOR transform: at every level (A0, A1) -> (A0 + A1, A0 - A1).
    // rev = true applies the inverse: (A0, A1) -> ((A0 + A1) / 2, (A0 - A1) / 2).
    // Requires a.size() to be a power of two and entries in [0, P).
    void FWTxor(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        const int inv2 = (P + 1) >> 1;  // 2 * inv2 = P + 1 ≡ 1 (mod P)
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1)
            for (int j = 0; j < n; j += l)
                for (int i = 0; i < m; i++) {
                    const int x = a[i + j], y = a[i + j + m];
                    if (!rev) {
                        a[i + j] = (x + y) % P;
                        a[i + j + m] = (x - y + P) % P;
                    } else {
                        a[i + j] = static_cast<int>(1LL * (x + y) * inv2 % P);
                        a[i + j + m] = static_cast<int>(1LL * (x - y + P) * inv2 % P);
                    }
                }
    }

    // OR convolution: C[k] = sum over i | j == k of a1[i] * a2[j] (mod P).
    // The result has length extend(max(a1.size(), a2.size())).
    std::vector<int> Or(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTor);
    }

    // AND convolution: C[k] = sum over i & j == k of a1[i] * a2[j] (mod P).
    // The result has length extend(max(a1.size(), a2.size())).
    std::vector<int> And(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTand);
    }

    // XOR convolution: C[k] = sum over i ^ j == k of a1[i] * a2[j] (mod P).
    // The result has length extend(max(a1.size(), a2.size())).
    std::vector<int> Xor(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTxor);
    }

private:
    // Pointer to one of the in-place transforms above.
    using Transform = void (FWT::*)(std::vector<int> &, bool);

    // Reduce every entry into [0, P). add/sub require this precondition (and
    // cannot overflow under it), and it makes results canonical even for
    // negative or oversized inputs.
    static void normalize(std::vector<int> &a) {
        for (int &v : a) v = (v % P + P) % P;
    }

    // Shared driver for Or/And/Xor: pad both inputs to a common power-of-two
    // length, transform, multiply pointwise, and transform back.
    // a1 and a2 are the callers' by-value copies and are modified in place.
    std::vector<int> convolve(std::vector<int> &a1, std::vector<int> &a2, Transform transform) {
        const int N = extend(static_cast<int>(std::max(a1.size(), a2.size())));
        a1.resize(N);
        a2.resize(N);
        normalize(a1);
        normalize(a2);
        (this->*transform)(a1, false);
        (this->*transform)(a2, false);
        std::vector<int> A(N);
        for (int i = 0; i < N; i++) A[i] = static_cast<int>(1LL * a1[i] * a2[i] % P);
        (this->*transform)(A, true);
        return A;
    }
} fwt;
int main(a) {
int n;
scanf("%d", &n);
std: :vector<int> a1(n), a2(n);
for (int i = 0; i < n; i++) scanf("%d", &a1[i]);
for (int i = 0; i < n; i++) scanf("%d", &a2[i]);
std: :vector<int> A;
A = fwt.Or(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.And(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.Xor(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
return 0;
}
Copy the code
Problem set
- Luogu 4717 [Template] Fast Walsh Transform
- BZOJ 4589 Hard Nim
- HihoCoder 1230 The Celebration of Rabbits
- HDU 5909 Tree Cutting
References
- FWT (Fast Walsh Transform) Study Notes – YYB