Original article: orzsiyuan.com/archives/Al…
The Fast Walsh Transform (FWT) is used to compute convolutions of polynomials in which indices are combined with a bitwise operation. Its computation process is similar to that of the FFT.
Overview
First, recall ordinary polynomial convolution:
$$C_k = \sum_{i + j = k} A_i B_j$$
In “Algorithm Notes: Fast Fourier Transform (FFT)” we converted the polynomials $A$ and $B$ into point-value form, multiplied them pointwise, and then converted the product back into coefficient form.
Now consider convolutions of the form
$$C_k = \sum_{i \circ j = k} A_i B_j,$$
where $\circ$ is bitwise OR ($\mid$), bitwise AND ($\&$), or bitwise XOR ($\oplus$).
These convolutions cannot be computed with the FFT, so we need a new method: the Fast Walsh Transform.
Below we solve each of these three kinds of convolution in turn. In each section:
- the **define** part constructs the transform;
- the **prove** part uses mathematical induction to prove the relevant lemma and theorem;
- the **code** part implements the convolution.
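Notation

First, we assume that every polynomial below has length $n = 2^k$, a nonnegative integer power of $2$. For convenience, we use the following symbols throughout without further explanation:

- $A_0, A_1$: the first and second halves of $A$, i.e. the coefficients whose highest index bit is $0$ and $1$ respectively;
- $(X, Y)$: the concatenation of $X$ and $Y$;
- $A + B$, $A - B$, $A \times B$: the coefficient-wise sum, difference, and product;
- $A \mid B$, $A \& B$, $A \oplus B$: the OR, AND, and XOR convolutions of $A$ and $B$.

Note in particular that the OR, AND, and XOR operations on polynomials defined here are convolutions! They are not the corresponding bitwise operations applied to the coefficients!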
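OR convolution

**Definition.**
$$\text{FWT}(A) = \begin{cases} (\text{FWT}(A_0), \text{FWT}(A_0) + \text{FWT}(A_1)) & n > 1 \\ A & n = 1 \end{cases}$$
Equivalently, $\text{FWT}(A)_i = \sum_{j \subseteq i} A_j$: the sum over all $j$ whose set bits are a subset of those of $i$, i.e. $j \mid i = i$.

The inverse transform undoes each step:
$$\text{IFWT}(A) = \begin{cases} (\text{IFWT}(A_0), \text{IFWT}(A_1) - \text{IFWT}(A_0)) & n > 1 \\ A & n = 1 \end{cases}$$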
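**Proof.**

- **Lemma.** The transform of the sum of two polynomials equals the sum of their transforms:
  $$\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$$
  - When $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 1$, by the induction hypothesis on the halves:
\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0), \text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(B_0), \text{FWT}(A_0) + \text{FWT}(A_1) + \text{FWT}(B_0) + \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0), \text{FWT}(A_0) + \text{FWT}(A_1)) + (\text{FWT}(B_0), \text{FWT}(B_0) + \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}
- **Theorem.** The transform of the OR convolution of two polynomials equals the product of their transforms:
  $$\text{FWT}(A \mid B) = \text{FWT}(A) \times \text{FWT}(B)$$
  - When $n = 1$: the only index is $0$ and $0 \mid 0 = 0$, so $A \mid B = A \times B$ and the claim holds.
  - When $n > 1$: a result index with highest bit $0$ arises only from two indices with highest bit $0$. Hence $A \mid B = (A_0 \mid B_0,\ A_0 \mid B_1 + A_1 \mid B_0 + A_1 \mid B_1)$. Using the lemma and the induction hypothesis:
\begin{align*}
\text{FWT}(A \mid B) & = (\text{FWT}(A_0 \mid B_0), \text{FWT}(A_0 \mid B_0) + \text{FWT}(A_0 \mid B_1 + A_1 \mid B_0 + A_1 \mid B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0), \text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) \\
& \qquad + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0), (\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1))) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0), \text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1)) \\
& = (\text{FWT}(A_0), \text{FWT}(A_0 + A_1)) \times (\text{FWT}(B_0), \text{FWT}(B_0 + B_1)) \\
& = \text{FWT}(A)\times \text{FWT}(B)
\end{align*}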
code
void FWTor(std: :vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if(! rev) add(a[i + j + m], a[i + j]);elsesub(a[i + j + m], a[i + j]); }}}Copy the code
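AND convolution

**Definition.**
$$\text{FWT}(A) = \begin{cases} (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_1)) & n > 1 \\ A & n = 1 \end{cases}$$
Equivalently, $\text{FWT}(A)_i = \sum_{j \supseteq i} A_j$: the sum over all $j$ whose set bits include those of $i$, i.e. $j \& i = i$.

The inverse transform undoes each step:
$$\text{IFWT}(A) = \begin{cases} (\text{IFWT}(A_0) - \text{IFWT}(A_1), \text{IFWT}(A_1)) & n > 1 \\ A & n = 1 \end{cases}$$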
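**Proof.**

- **Lemma.** The transform of the sum of two polynomials equals the sum of their transforms:
  $$\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$$
  - When $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 1$, by the induction hypothesis on the halves:
\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1), \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(B_0) + \text{FWT}(A_1) + \text{FWT}(B_1), \text{FWT}(A_1) + \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1), \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}
- **Theorem.** The transform of the AND convolution of two polynomials equals the product of their transforms:
  $$\text{FWT}(A \& B) = \text{FWT}(A) \times \text{FWT}(B)$$
  - When $n = 1$: the only index is $0$ and $0 \& 0 = 0$, so $A \& B = A \times B$ and the claim holds.
  - When $n > 1$: a result index with highest bit $1$ arises only from two indices with highest bit $1$. Hence $A \& B = (A_0 \& B_0 + A_0 \& B_1 + A_1 \& B_0,\ A_1 \& B_1)$. Using the lemma and the induction hypothesis:
\begin{align*}
\text{FWT}(A \& B) & = (\text{FWT}(A_0 \& B_0 + A_0 \& B_1 + A_1 \& B_0) + \text{FWT}(A_1 \& B_1), \text{FWT}(A_1 \& B_1)) \\
& = (\text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1), \\
& \qquad \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = ((\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1)), \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1), \text{FWT}(A_1) \times \text{FWT}(B_1)) \\
& = (\text{FWT}(A_0 + A_1), \text{FWT}(A_1)) \times (\text{FWT}(B_0 + B_1), \text{FWT}(B_1)) \\
& = \text{FWT}(A)\times \text{FWT}(B)
\end{align*}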
code
void FWTand(std: :vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if(! rev) add(a[i + j], a[i + j + m]);elsesub(a[i + j], a[i + j + m]); }}}Copy the code
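XOR convolution

**Definition.**
$$\text{FWT}(A) = \begin{cases} (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_0) - \text{FWT}(A_1)) & n > 1 \\ A & n = 1 \end{cases}$$
Equivalently, $\text{FWT}(A)_i = \sum_{j} (-1)^{\text{popcount}(i \& j)} A_j$.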
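**Proof.**

- **Lemma.** The transform of the sum of two polynomials equals the sum of their transforms:
  $$\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$$
  The same argument gives $\text{FWT}(A - B) = \text{FWT}(A) - \text{FWT}(B)$.
  - When $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.
  - When $n > 1$, by the induction hypothesis on the halves:
\begin{align*}
\text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1), \text{FWT}(A_0 + B_0) - \text{FWT}(A_1 + B_1)) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_0) - \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1), \text{FWT}(B_0) - \text{FWT}(B_1)) \\
& = \text{FWT}(A) + \text{FWT}(B)
\end{align*}
- **Theorem.** The transform of the XOR convolution of two polynomials equals the product of their transforms:
  $$\text{FWT}(A \oplus B) = \text{FWT}(A) \times \text{FWT}(B)$$
  - When $n = 1$: the only index is $0$ and $0 \oplus 0 = 0$, so $A \oplus B = A \times B$ and the claim holds.
  - When $n > 1$: the highest bit of a result index is $0$ exactly when both operands share the same highest bit. Hence $A \oplus B = (A_0 \oplus B_0 + A_1 \oplus B_1,\ A_0 \oplus B_1 + A_1 \oplus B_0)$. Expanding with the lemma and the induction hypothesis:
\begin{align*}
\text{FWT}(A \oplus B) & = (\text{FWT}(A_0 \oplus B_0 + A_1 \oplus B_1) + \text{FWT}(A_0 \oplus B_1 + A_1 \oplus B_0), \\
& \qquad \text{FWT}(A_0 \oplus B_0 + A_1 \oplus B_1) - \text{FWT}(A_0 \oplus B_1 + A_1 \oplus B_0)) \\
& = ((\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1)), (\text{FWT}(A_0) - \text{FWT}(A_1)) \times (\text{FWT}(B_0) - \text{FWT}(B_1))) \\
& = (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_0) - \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1), \text{FWT}(B_0) - \text{FWT}(B_1)) \\
& = \text{FWT}(A) \times \text{FWT}(B)
\end{align*}

**Inverse transformation.** Undoing each step gives
$$\text{IFWT}(A) = \begin{cases} \left(\dfrac{\text{IFWT}(A_0) + \text{IFWT}(A_1)}{2}, \dfrac{\text{IFWT}(A_0) - \text{IFWT}(A_1)}{2}\right) & n > 1 \\ A & n = 1 \end{cases}$$
Modulo an odd prime $P$, dividing by $2$ means multiplying by $2^{-1} = \frac{P + 1}{2}$.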
code
void FWTxor(std: :vector<int> &a, bool rev) {
int n = a.size(), inv2 = (P + 1) > >1;
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
int x = a[i + j], y = a[i + j + m];
if(! rev) { a[i + j] = (x + y) % P; a[i + j + m] = (x - y + P) % P; }else {
a[i + j] = 1LL * (x + y) * inv2 % P;
a[i + j + m] = 1LL * (x - y + P) * inv2 % P; }}}}Copy the code
The complete code
#include <cstdio>
#include <algorithm>
#include <vector>
// Modulus for every computation below. It is an odd prime, so 2 is invertible.
const int P = 998244353;

// x <- (x + y) mod P. Assumes x and y already lie in [0, P), so one conditional
// subtraction suffices and x + y < 2P cannot overflow int.
void add(int &x, int y) {
    x += y;
    if (x >= P) x -= P;
}

// x <- (x - y) mod P. Assumes x and y already lie in [0, P).
void sub(int &x, int y) {
    x -= y;
    if (x < 0) x += P;
}

// Fast Walsh transforms and the OR / AND / XOR convolutions built on them.
// All arithmetic is modulo P; every coefficient is expected to lie in [0, P).
struct FWT {
    // Smallest power of two that is >= n (1 when n <= 1).
    int extend(int n) {
        int N = 1;
        while (N < n) N <<= 1;
        return N;
    }

    // In-place OR transform; a.size() must be a power of two.
    // Forward: a[S] becomes the sum of a[T] over all subsets T of S.
    // rev == true applies the inverse transform instead.
    void FWTor(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
            for (int j = 0; j < n; j += l) {
                for (int i = 0; i < m; i++) {
                    if (!rev) add(a[i + j + m], a[i + j]);
                    else sub(a[i + j + m], a[i + j]);
                }
            }
        }
    }

    // In-place AND transform; a.size() must be a power of two.
    // Forward: a[S] becomes the sum of a[T] over all supersets T of S.
    // rev == true applies the inverse transform instead.
    void FWTand(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
            for (int j = 0; j < n; j += l) {
                for (int i = 0; i < m; i++) {
                    if (!rev) add(a[i + j], a[i + j + m]);
                    else sub(a[i + j], a[i + j + m]);
                }
            }
        }
    }

    // In-place XOR transform; a.size() must be a power of two.
    // Forward: a[S] becomes sum over T of (-1)^popcount(S & T) * a[T] (mod P).
    // rev == true applies the inverse, halving at every level (1/n overall).
    void FWTxor(std::vector<int> &a, bool rev) {
        const int n = static_cast<int>(a.size());
        const int inv2 = (P + 1) >> 1;  // 2 * inv2 == P + 1 == 1 (mod P)
        for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
            for (int j = 0; j < n; j += l) {
                for (int i = 0; i < m; i++) {
                    const int x = a[i + j], y = a[i + j + m];
                    if (!rev) {
                        a[i + j] = (x + y) % P;
                        a[i + j + m] = (x - y + P) % P;
                    } else {
                        a[i + j] = static_cast<int>(1LL * (x + y) * inv2 % P);
                        a[i + j + m] = static_cast<int>(1LL * (x - y + P) * inv2 % P);
                    }
                }
            }
        }
    }

    // OR convolution: C[k] = sum of a1[i] * a2[j] over all i | j == k.
    // The result has the smallest power-of-two length >= max(|a1|, |a2|).
    std::vector<int> Or(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTor);
    }

    // AND convolution: C[k] = sum of a1[i] * a2[j] over all i & j == k.
    // The result has the smallest power-of-two length >= max(|a1|, |a2|).
    std::vector<int> And(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTand);
    }

    // XOR convolution: C[k] = sum of a1[i] * a2[j] over all i ^ j == k.
    // The result has the smallest power-of-two length >= max(|a1|, |a2|).
    std::vector<int> Xor(std::vector<int> a1, std::vector<int> a2) {
        return convolve(a1, a2, &FWT::FWTxor);
    }

    // Shared driver for Or / And / Xor. It pads both inputs with zeros to a
    // common power-of-two length, transforms them, multiplies pointwise, and
    // inverts the transform. a1 and a2 are scratch buffers and are overwritten.
    std::vector<int> convolve(std::vector<int> &a1, std::vector<int> &a2,
                              void (FWT::*transform)(std::vector<int> &, bool)) {
        const int N = extend(static_cast<int>(std::max(a1.size(), a2.size())));
        a1.resize(N);
        a2.resize(N);
        (this->*transform)(a1, false);
        (this->*transform)(a2, false);
        for (int i = 0; i < N; i++) {
            // Widen to long long so the product of two values < P cannot overflow.
            a1[i] = static_cast<int>(1LL * a1[i] * a2[i] % P);
        }
        (this->*transform)(a1, true);
        return a1;
    }
} fwt;
int main(a) {
int n;
scanf("%d", &n);
std: :vector<int> a1(n), a2(n);
for (int i = 0; i < n; i++) scanf("%d", &a1[i]);
for (int i = 0; i < n; i++) scanf("%d", &a2[i]);
std: :vector<int> A;
A = fwt.Or(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.And(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.Xor(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
return 0;
}
Copy the code
Practice problems
- Luogu 4717 [Template] Fast Walsh transform
- “BZOJ 4589” Hard Nim
- HihoCoder 1230 The Celebration of Rabbits
- “HDU 5909” Tree Cutting
References
- FWT Fast Walsh Transform learning Notes – YYB