Original article: orzsiyuan.com/archives/Al…

The fast Walsh transform (FWT) is used to compute convolutions of polynomials whose exponents are combined by a bitwise operation. Its computation process is similar to that of the fast Fourier transform (FFT).


Overview

First, recall ordinary polynomial convolution:

$$C_k = \sum_{i + j = k} A_i B_j$$


In “Algorithm Notes: Fast Fourier Transform (FFT)” we converted the polynomials $A$ and $B$ from coefficient representation to point-value representation, multiplied them pointwise, and then converted the product back to coefficient representation.

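Now consider convolutions of the following forms:

$$C_k = \sum_{i \mid j = k} A_i B_j, \qquad C_k = \sum_{i \,\&\, j = k} A_i B_j, \qquad C_k = \sum_{i \oplus j = k} A_i B_j$$

where $\mid$, $\&$ and $\oplus$ denote bitwise or, bitwise and, and bitwise xor respectively.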

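Because the indices are combined by bitwise operations rather than by addition, these convolutions cannot be computed with FFT. We therefore need a new method: the fast Walsh transform (FWT).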

Below we solve each of these $3$ kinds of convolution in turn. In each section:

- the **define** part constructs the transformation;
- the **prove** part uses mathematical induction to prove the lemmas it relies on;
- the **code** part implements the convolution.


Notation

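First, we assume that every polynomial below has length $n = 2^k$ for some nonnegative integer $k$ (pad with zeros otherwise). For convenience, we use the following notation throughout without further explanation:

  • $A_0$ / $A_1$: the first / second half of $A$, i.e. the coefficients whose index has highest bit $0$ / $1$;
  • $(A, B)$: the polynomial obtained by concatenating $A$ and $B$;
  • $A + B$, $A - B$, $A \times B$: coefficient-wise sum, difference and product;
  • $A \mid B$, $A \,\&\, B$, $A \oplus B$: the or, and, and xor convolutions of $A$ and $B$;
  • $\text{FWT}(A)$ / $\text{IFWT}(A)$: the transform of $A$ and its inverse.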


Note in particular: the operations $A \mid B$, $A \,\&\, B$ and $A \oplus B$ defined here are convolutions, not bitwise operations applied to corresponding coefficients!


Or convolution

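define

$$\text{FWT}(A) = \begin{cases} A & n = 1 \\ (\text{FWT}(A_0), \text{FWT}(A_0) + \text{FWT}(A_1)) & n > 1 \end{cases}$$

Equivalently, $\text{FWT}(A)_i = \sum_{i \mid j = i} A_j$, the sum over all subsets $j$ of $i$. The inverse transform is $\text{IFWT}(A) = (\text{IFWT}(A_0), \text{IFWT}(A_1) - \text{IFWT}(A_0))$.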


prove

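  1. The transform of the sum of two polynomials equals the sum of their transforms: $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n > 1$, assuming the claim holds for length $\frac{n}{2}$:

      $$\begin{align*} \text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0), \text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1)) \\ & = (\text{FWT}(A_0) + \text{FWT}(B_0), \text{FWT}(A_0) + \text{FWT}(A_1) + \text{FWT}(B_0) + \text{FWT}(B_1)) \\ & = (\text{FWT}(A_0), \text{FWT}(A_0) + \text{FWT}(A_1)) + (\text{FWT}(B_0), \text{FWT}(B_0) + \text{FWT}(B_1)) \\ & = \text{FWT}(A) + \text{FWT}(B) \end{align*}$$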

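  2. The transform of the or convolution of two polynomials equals the product of their transforms: $\text{FWT}(A \mid B) = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A \mid B) = A \times B = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n > 1$: splitting by the highest bit gives $A \mid B = (A_0 \mid B_0, (A_0 \mid B_1) + (A_1 \mid B_0) + (A_1 \mid B_1))$, so by the induction hypothesis and property 1:

      $$\begin{align*} \text{FWT}(A \mid B) & = (\text{FWT}(A_0 \mid B_0), \text{FWT}(A_0 \mid B_0) + \text{FWT}(A_0 \mid B_1) + \text{FWT}(A_1 \mid B_0) + \text{FWT}(A_1 \mid B_1)) \\ & = (\text{FWT}(A_0) \times \text{FWT}(B_0), \text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1)) \\ & = (\text{FWT}(A_0) \times \text{FWT}(B_0), (\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1))) \\ & = (\text{FWT}(A_0) \times \text{FWT}(B_0), \text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1)) \\ & = (\text{FWT}(A_0), \text{FWT}(A_0 + A_1)) \times (\text{FWT}(B_0), \text{FWT}(B_0 + B_1)) \\ & = \text{FWT}(A)\times \text{FWT}(B) \end{align*}$$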



code

void FWTor(std: :vector<int> &a, bool rev) {
	int n = a.size();
	for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
		for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
			if(! rev) add(a[i + j + m], a[i + j]);elsesub(a[i + j + m], a[i + j]); }}}Copy the code

And convolution

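define

$$\text{FWT}(A) = \begin{cases} A & n = 1 \\ (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_1)) & n > 1 \end{cases}$$

Equivalently, $\text{FWT}(A)_i = \sum_{i \,\&\, j = i} A_j$, the sum over all supersets $j$ of $i$. The inverse transform is $\text{IFWT}(A) = (\text{IFWT}(A_0) - \text{IFWT}(A_1), \text{IFWT}(A_1))$.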


prove

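  1. The transform of the sum of two polynomials equals the sum of their transforms: $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n > 1$, assuming the claim holds for length $\frac{n}{2}$:

      $$\begin{align*} \text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1), \text{FWT}(A_1 + B_1)) \\ & = (\text{FWT}(A_0) + \text{FWT}(B_0) + \text{FWT}(A_1) + \text{FWT}(B_1), \text{FWT}(A_1) + \text{FWT}(B_1)) \\ & = (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1), \text{FWT}(B_1)) \\ & = \text{FWT}(A) + \text{FWT}(B) \end{align*}$$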


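  2. The transform of the and convolution of two polynomials equals the product of their transforms: $\text{FWT}(A \,\&\, B) = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A \,\&\, B) = A \times B = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n > 1$: splitting by the highest bit gives $A \,\&\, B = ((A_0 \,\&\, B_0) + (A_0 \,\&\, B_1) + (A_1 \,\&\, B_0), A_1 \,\&\, B_1)$, so by the induction hypothesis and property 1:

      $$\begin{align*} \text{FWT}(A \,\&\, B) & = (\text{FWT}(A_0 \,\&\, B_0) + \text{FWT}(A_0 \,\&\, B_1) + \text{FWT}(A_1 \,\&\, B_0) + \text{FWT}(A_1 \,\&\, B_1), \\ & \qquad \text{FWT}(A_1 \,\&\, B_1)) \\ & = (\text{FWT}(A_0) \times \text{FWT}(B_0) + \text{FWT}(A_0) \times \text{FWT}(B_1) + \text{FWT}(A_1) \times \text{FWT}(B_0) + \text{FWT}(A_1) \times \text{FWT}(B_1), \\ & \qquad \text{FWT}(A_1) \times \text{FWT}(B_1)) \\ & = ((\text{FWT}(A_0) + \text{FWT}(A_1)) \times (\text{FWT}(B_0) + \text{FWT}(B_1)), \text{FWT}(A_1) \times \text{FWT}(B_1)) \\ & = (\text{FWT}(A_0 + A_1) \times \text{FWT}(B_0 + B_1), \text{FWT}(A_1) \times \text{FWT}(B_1)) \\ & = (\text{FWT}(A_0 + A_1), \text{FWT}(A_1)) \times (\text{FWT}(B_0 + B_1), \text{FWT}(B_1)) \\ & = \text{FWT}(A)\times \text{FWT}(B) \end{align*}$$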



code

void FWTand(std: :vector<int> &a, bool rev) {
	int n = a.size();
	for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
		for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
			if(! rev) add(a[i + j], a[i + j + m]);elsesub(a[i + j], a[i + j + m]); }}}Copy the code

Xor (exclusive-or) convolution

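define

$$\text{FWT}(A) = \begin{cases} A & n = 1 \\ (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_0) - \text{FWT}(A_1)) & n > 1 \end{cases}$$

Equivalently, $\text{FWT}(A)_i = \sum_j (-1)^{\text{popcount}(i \,\&\, j)} A_j$.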


prove

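  1. The transform of the sum of two polynomials equals the sum of their transforms: $\text{FWT}(A + B) = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A + B) = A + B = \text{FWT}(A) + \text{FWT}(B)$.

    • when $n > 1$, assuming the claim holds for length $\frac{n}{2}$:

      $$\begin{align*} \text{FWT}(A + B) & = (\text{FWT}(A_0 + B_0) + \text{FWT}(A_1 + B_1), \text{FWT}(A_0 + B_0) - \text{FWT}(A_1 + B_1)) \\ & = (\text{FWT}(A_0) + \text{FWT}(A_1), \text{FWT}(A_0) - \text{FWT}(A_1)) + (\text{FWT}(B_0) + \text{FWT}(B_1), \text{FWT}(B_0) - \text{FWT}(B_1)) \\ & = \text{FWT}(A) + \text{FWT}(B) \end{align*}$$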

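  2. The transform of the xor convolution of two polynomials equals the product of their transforms: $\text{FWT}(A \oplus B) = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n = 1$: $\text{FWT}(A \oplus B) = A \times B = \text{FWT}(A) \times \text{FWT}(B)$.

    • when $n > 1$: splitting by the highest bit gives $A \oplus B = ((A_0 \oplus B_0) + (A_1 \oplus B_1), (A_0 \oplus B_1) + (A_1 \oplus B_0))$. Write $X_0 = \text{FWT}(A_0)$, $X_1 = \text{FWT}(A_1)$, $Y_0 = \text{FWT}(B_0)$, $Y_1 = \text{FWT}(B_1)$. The induction hypothesis and property 1 then give:

      $$\begin{align*} \text{FWT}(A \oplus B) & = (X_0 \times Y_0 + X_1 \times Y_1 + X_0 \times Y_1 + X_1 \times Y_0, X_0 \times Y_0 + X_1 \times Y_1 - X_0 \times Y_1 - X_1 \times Y_0) \\ & = ((X_0 + X_1) \times (Y_0 + Y_1), (X_0 - X_1) \times (Y_0 - Y_1)) \\ & = (X_0 + X_1, X_0 - X_1) \times (Y_0 + Y_1, Y_0 - Y_1) \\ & = \text{FWT}(A) \times \text{FWT}(B) \end{align*}$$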


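Inverse transformation

$$\text{IFWT}(A) = \begin{cases} A & n = 1 \\ \left(\dfrac{\text{IFWT}(A_0) + \text{IFWT}(A_1)}{2}, \dfrac{\text{IFWT}(A_0) - \text{IFWT}(A_1)}{2}\right) & n > 1 \end{cases}$$

Modulo an odd prime $P$, dividing by $2$ means multiplying by the modular inverse $2^{-1} = \frac{P + 1}{2}$.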


code

void FWTxor(std: :vector<int> &a, bool rev) {
	int n = a.size(), inv2 = (P + 1) > >1;
	for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
		for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
			int x = a[i + j], y = a[i + j + m];
			if(! rev) { a[i + j] = (x + y) % P; a[i + j + m] = (x - y + P) % P; }else {
				a[i + j] = 1LL * (x + y) * inv2 % P;
				a[i + j + m] = 1LL * (x - y + P) * inv2 % P; }}}}Copy the code

The complete code

#include <cstdio>
#include <algorithm>
#include <vector>

const int P = 998244353;

// x = (x + y) mod P, in place. Both operands must already lie in [0, P); their
// sum is then below 2P, which still fits in an int for this P.
void add(int &x, int y) {
	(x += y) >= P && (x -= P);
}
// x = (x - y) mod P, in place. Both operands must already lie in [0, P).
void sub(int &x, int y) {
	(x -= y) < 0 && (x += P);
}

// Fast Walsh transforms and the or / and / xor convolutions built on them.
// All arithmetic is done modulo P.
struct FWT {
	// Smallest power of two that is >= n (1 when n <= 1).
	int extend(int n) {
		int N = 1;
		for (; N < n; N <<= 1);
		return N;
	}

	// In-place or-transform. a.size() must be a power of two, entries in [0, P).
	// Forward: a[i] becomes the sum of a[j] over all subsets j of i.
	// Inverse (rev == true): undoes the forward transform.
	void FWTor(std::vector<int> &a, bool rev) {
		int n = a.size();
		for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
			for (int j = 0; j < n; j += l) {
				for (int i = 0; i < m; i++) {
					// The half whose current bit is 1 absorbs (or gives back) the half whose bit is 0.
					if (!rev) add(a[i + j + m], a[i + j]);
					else sub(a[i + j + m], a[i + j]);
				}
			}
		}
	}

	// In-place and-transform. a.size() must be a power of two, entries in [0, P).
	// Forward: a[i] becomes the sum of a[j] over all supersets j of i.
	// Inverse (rev == true): undoes the forward transform.
	void FWTand(std::vector<int> &a, bool rev) {
		int n = a.size();
		for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
			for (int j = 0; j < n; j += l) {
				for (int i = 0; i < m; i++) {
					// The half whose current bit is 0 absorbs (or gives back) the half whose bit is 1.
					if (!rev) add(a[i + j], a[i + j + m]);
					else sub(a[i + j], a[i + j + m]);
				}
			}
		}
	}

	// In-place xor-transform. a.size() must be a power of two, entries in [0, P).
	// Forward: a[i] becomes the sum over j of (-1)^popcount(i & j) * a[j].
	// Inverse (rev == true): same butterfly with both outputs halved per level.
	void FWTxor(std::vector<int> &a, bool rev) {
		int n = a.size();
		int inv2 = (P + 1) >> 1;  // modular inverse of 2, valid because P is odd
		for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
			for (int j = 0; j < n; j += l) {
				for (int i = 0; i < m; i++) {
					int x = a[i + j], y = a[i + j + m];
					if (!rev) {
						a[i + j] = (x + y) % P;
						a[i + j + m] = (x - y + P) % P;
					} else {
						a[i + j] = 1LL * (x + y) * inv2 % P;
						a[i + j + m] = 1LL * (x - y + P) * inv2 % P;
					}
				}
			}
		}
	}

	// c[k] = sum of a1[i] * a2[j] over i | j == k, modulo P.
	// The result length is the padded power of two.
	std::vector<int> Or(std::vector<int> a1, std::vector<int> a2) {
		return convolve(a1, a2, &FWT::FWTor);
	}
	// c[k] = sum of a1[i] * a2[j] over i & j == k, modulo P.
	// The result length is the padded power of two.
	std::vector<int> And(std::vector<int> a1, std::vector<int> a2) {
		return convolve(a1, a2, &FWT::FWTand);
	}
	// c[k] = sum of a1[i] * a2[j] over i ^ j == k, modulo P.
	// The result length is the padded power of two.
	std::vector<int> Xor(std::vector<int> a1, std::vector<int> a2) {
		return convolve(a1, a2, &FWT::FWTxor);
	}

private:
	// Shared driver for Or / And / Xor. It pads both inputs to a common
	// power-of-two length, reduces them into [0, P), applies the forward
	// transform, multiplies pointwise and applies the inverse transform.
	std::vector<int> convolve(std::vector<int> &a1, std::vector<int> &a2,
	                          void (FWT::*transform)(std::vector<int> &, bool)) {
		int n = static_cast<int>(std::max(a1.size(), a2.size()));
		int N = extend(n);
		a1.resize(N);
		normalize(a1);
		(this->*transform)(a1, false);
		a2.resize(N);
		normalize(a2);
		(this->*transform)(a2, false);
		std::vector<int> A(N);
		for (int i = 0; i < N; i++) A[i] = 1LL * a1[i] * a2[i] % P;
		(this->*transform)(A, true);
		return A;
	}

	// Reduces every entry into [0, P). add / sub and the transforms assume
	// this, so negative or >= P inputs would otherwise give wrong results.
	static void normalize(std::vector<int> &a) {
		for (std::size_t i = 0; i < a.size(); i++) a[i] = (a[i] % P + P) % P;
	}
} fwt;

int main(a) {
	int n;
	scanf("%d", &n);
	std: :vector<int> a1(n), a2(n);
	for (int i = 0; i < n; i++) scanf("%d", &a1[i]);
	for (int i = 0; i < n; i++) scanf("%d", &a2[i]);
	std: :vector<int> A;
	A = fwt.Or(a1, a2);
	for (int i = 0; i < n; i++) {
		printf("%d%c", A[i], " \n"[i == n - 1]);
	}
	A = fwt.And(a1, a2);
	for (int i = 0; i < n; i++) {
		printf("%d%c", A[i], " \n"[i == n - 1]);
	}
	A = fwt.Xor(a1, a2);
	for (int i = 0; i < n; i++) {
		printf("%d%c", A[i], " \n"[i == n - 1]);
	}
	return 0;
}
Copy the code

Problem sets

  • Luogu 4717 [Template] Fast Walsh transform
  • “BZOJ 4589” Hard Nim
  • HihoCoder 1230 The Celebration of Rabbits
  • “HDU 5909” Tree Cutting

References

  • FWT Fast Walsh Transform learning Notes – YYB