快速傅里叶变换⚓︎
FFT⚓︎
快速傅立叶变换(\text{Fast Fourier Transform,FFT})是一种高效实现离散傅里叶变换(\text{DFT})的算法。
\text{FFT} 算法的基本思想是分治。就 \text{DFT} 来说,它分治地来求当 x = \omega_n^k 的时候 f(x) 的值。
\text{FFT} 算法的分治思想体现在将多项式分为奇次项和偶次项处理,分别用奇偶次项系数建立新的函数:f(x) = Even(x^2) + x * Odd(x^2)。
其中Even(x) = \sum a_{2i}x^i,Odd(x) = \sum a_{2i+1}x^i。
利用偶数次单位根的性质 \omega^i_n = -\omega^{i + n/2}_n,并注意 Even(x^2) 和 Odd(x^2) 只依赖于 x^2,因此在复平面上 \omega^i_n 和 \omega^{i+n/2}_n 对应的 Even(x^2) 和 Odd(x^2) 的值相同。以下推导中,令 n 为变换长度,k 为当前求值的下标。
因此求出了 Even(\omega_{n/2}^k) 和 Odd(\omega_{n/2}^k) 后,就可以同时求出 f(\omega_n^k) 和 f(\omega_n^{k+n/2})。
分治 \text{DFT} 能处理的多项式长度只能是 2^m(m \in \mathbf{N}^ \ast ),否则在分治的时候左右不一样长,右边取不到系数。所以要在第一次 \text{DFT} 之前就把序列向上补成长度为 2^m(高次系数补 0)、最高项次数为 2^m-1 的多项式。
在代入值的时候,因为要代入n个不同值:\omega_n^0,\omega_n^1,\omega_n^2,\cdots, \omega_n^{n-1} (n=2^m) 一共 2^m 个不同值。
注意:
- 在调用 \text{FFT} 前,需要将多项式的系数转化为复数,同时要将系数调整为 2^m 长度的多项式。
- 在求快速傅里叶变换的逆变换时,注意角度要乘以 -1,这在代码中通过
inv参数来控制。 - 求出快速傅里叶变换的逆变换后,需要正则化,将结果除以 n。
- 多项式乘法h(x) = f(x) \times g(x),可以通过 h(x) = \text{FFT}^{-1}(\text{FFT}(f(x)) \times \text{FFT}(g(x))) 来实现。在最后的结果中,需要将系数四舍五入取整。
w_k 的计算方式
\text{FFT}:
\text{Inverse FFT}:
递推关系:
【模板】多项式乘法(FFT)
给定两个多项式 f(x) 和 g(x),求它们的乘积多项式 h(x) = f(x) \times g(x)。
Keypoints
一般来说,\text{FFT} 的实现分为两种:递归实现和优化实现。递归实现的代码较为简单,优化实现则在此基础上进行了位逆序置换的优化。
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <vector>
using namespace std;
void fft(vector<complex<double>> &f, int inv) {
int n = f.size(); // n must be a power of 2
if (n == 1) { return; }
vector<complex<double>> even(n / 2), odd(n / 2);
for (int i = 0; i < n / 2; i++) {
even[i] = f[i * 2];
odd[i] = f[i * 2 + 1];
}
fft(even, inv);
fft(odd, inv);
double theta = 2 * numbers::pi / n;
complex<double> wk(1), w(cos(theta), inv * sin(theta));
for (int k = 0; k < n / 2; k++) {
f[k] = even[k] + wk * odd[k];
f[k + n / 2] = even[k] - wk * odd[k];
wk *= w;
}
}
int main() {
int n, m;
cin >> n >> m; // n, m 为f(x)和g(x)的最高次
int len = pow(2, ceil(log2(n + m + 1))); // 调整结果的长度为2的幂次
vector<complex<double>> f(len), g(len), h(len);
for (int i = 0; i <= n; i++) { cin >> f[i]; }
for (int i = 0; i <= m; i++) { cin >> g[i]; }
fft(f, 1);
fft(g, 1);
for (int i = 0; i < len; i++) { h[i] = f[i] * g[i]; } // 计算每个点的乘积
fft(h, -1); // 求h(x)的逆FFT
// 逆变换正则化, 四舍五入取整
for (int i = 0; i < n + m + 1; i++) { cout << round(h[i].real() / len) << " "; }
return 0;
}
位逆序置换⚓︎
在 FFT 中,位逆序置换是一个重要的步骤。考虑递归过程,在每次递归会两两分组。
以8项多项式为例,模拟拆分的过程:
- 初始序列为 \{x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7\}
- 一次二分之后 \{x_0, x_2, x_4, x_6\},\{x_1, x_3, x_5, x_7 \}
- 两次二分之后 \{x_0,x_4\} \{x_2, x_6\},\{x_1, x_5\},\{x_3, x_7 \}
- 三次二分之后 \{x_0\}\{x_4\}\{x_2\}\{x_6\}\{x_1\}\{x_5\}\{x_3\}\{x_7 \}
规律:其实就是原来的序列每个数用二进制表示,然后把二进制翻转,就是最终位置的下标。
在 O(n\log n) 的时间复杂度内,可以将序列进行位逆序置换。
位逆序置换代码实现
位逆序置换优化
在cpp中也可以直接使用bitset来翻转二进制位,时间复杂度与位数有关。
该实现的时间复杂度为 O(n\log n)。
该实现的时间复杂度为 O(n)。
void bit_reverse(vector<complex<double>> &f) {
int n = f.size();
vector<int> reverse(n);
for (int i = 0; i < n; ++i) {
reverse[i] = reverse[i >> 1] >> 1;
if ((i & 1) != 0) { reverse[i] |= n >> 1; } // 如果最后一位是 1,则翻转成 n/2
}
for (int i = 0; i < n; ++i) { // 保证每对数只翻转一次
if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
}
}
蝶形操作⚓︎
在 \text{FFT} 中,每次递归都会两两分组,然后进行蝶形操作。
蝶形操作是一种两两合并的操作,每次合并两个数,然后乘上一个单位根。在蝶形操作中,每个数都会和另一个数进行一次乘法和一次加法。
【模板】多项式乘法(FFT)
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <utility>
#include <vector>
using namespace std;
void bit_reverse(vector<complex<double>> &f) {
int n = f.size();
vector<int> reverse(n);
for (int i = 0; i < n; ++i) {
reverse[i] = reverse[i >> 1] >> 1;
if ((i & 1) != 0) { reverse[i] |= n >> 1; } // 如果最后一位是 1,则翻转成 n/2
}
for (int i = 0; i < n; ++i) { // 保证每对数只翻转一次
if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
}
}
void fft(vector<complex<double>> &f, int inv) {
int n = f.size();
bit_reverse(f);
for (int len = 2; len <= n; len *= 2) { // 递归对应的长度
double theta = 2 * numbers::pi / len * inv;
complex<double> w(cos(theta), sin(theta)); // 当前单位根
for (int i = 0; i < n; i += len) {
complex<double> wi(1);
for (int j = 0; j < len / 2; j++) {
complex<double> even = f[i + j], odd = wi * f[i + j + len / 2];
f[i + j] = even + odd;
f[i + j + len / 2] = even - odd;
wi *= w;
}
}
}
if (inv == -1) { // 逆变换正则化
for (int i = 0; i < n; i++) { f[i] /= n; }
}
}
int main() {
int n, m;
cin >> n >> m;
int len = pow(2, ceil(log2(n + m + 1)));
vector<complex<double>> f(len), g(len), h(len);
for (int i = 0; i <= n; i++) { cin >> f[i]; }
for (int i = 0; i <= m; i++) { cin >> g[i]; }
fft(f, 1);
fft(g, 1);
for (int i = 0; i < len; i++) { h[i] = f[i] * g[i]; }
fft(h, -1);
for (int i = 0; i < n + m + 1; i++) { cout << round(h[i].real()) << " "; }
return 0;
}
Tip
并行扫描算法的思路与蝶形变换类似
应用⚓︎
- 多项式乘法 h(x) = f(x) \times g(x)
- 多项式求逆 g(x) = f(x)^{-1}
f(x) \times g(x) = 1,g(x) = f(x)^{-1},g(x) 为 f(x) 的逆元
Thief in a Shop
n 种商品,每种价格为 a_1, a_2, \dots, a_n,每种商品数量无限。恰好带走 k 件商品(同一种可以多件),问可能的总价格有哪些。
Hint
定义多项式生成函数:
- 每次选一件商品,对应取一次 F(x)
- 选 k 件商品 → 求 k 次幂:F(x)^k
- x^s 的系数表示总价 s 是否可达,所以答案就是 F(x)^k 中所有系数不为 0 的项的指数
F(x)^k 可以通过 \text{FFT} 快速卷积计算。由于系数只关心是否为 0,所以在每次卷积后将系数大于 0 的项置为 1 即可。
另外 k 件物品最大价格为 k \times \max(a_i),所以结果的最高次幂不会超过这个值。可以每次卷积后将多项式截断到这个长度以节省时间和空间。
#include <algorithm>
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <utility>
#include <vector>
using namespace std;
using cd = complex<double>;
void bit_reverse(vector<complex<double>> &f) {
int n = f.size();
vector<int> reverse(n);
for (int i = 0; i < n; ++i) {
reverse[i] = reverse[i >> 1] >> 1;
if ((i & 1) != 0) { reverse[i] |= n >> 1; } // 如果最后一位是 1,则翻转成 n/2
}
for (int i = 0; i < n; ++i) { // 保证每对数只翻转一次
if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
}
}
void fft(vector<complex<double>> &f, int inv) {
int n = f.size();
bit_reverse(f);
for (int len = 2; len <= n; len *= 2) { // 递归对应的长度
double theta = 2 * numbers::pi / len * inv;
complex<double> w(cos(theta), sin(theta)); // 当前单位根
for (int i = 0; i < n; i += len) {
complex<double> wi(1);
for (int j = 0; j < len / 2; j++) {
complex<double> even = f[i + j], odd = wi * f[i + j + len / 2];
f[i + j] = even + odd;
f[i + j + len / 2] = even - odd;
wi *= w;
}
}
}
if (inv == -1) { // 逆变换正则化
for (int i = 0; i < n; i++) { f[i] /= n; }
}
}
// 布尔卷积(动态截断到 max_len)
vector<int> multiply(const vector<int> &a, const vector<int> &b, int max_len) {
int n = 1;
while (n < a.size() + b.size()) { n <<= 1; } // 找到大于等于两多项式和的最小 2 的幂
vector<cd> fa(n), fb(n);
for (int i = 0; i < a.size(); i++) { fa[i] = a[i]; }
for (int i = 0; i < b.size(); i++) { fb[i] = b[i]; }
fft(fa, 1);
fft(fb, 1);
for (int i = 0; i < n; i++) { fa[i] *= fb[i]; } // 点乘
fft(fa, -1); // 逆变换
vector<int> res(min(n, max_len), 0); // 截断到 max_len
// 只关心总价是否可达, 大于 0 则视为可达
for (int i = 0; i < res.size(); i++) { res[i] = (fa[i].real() > 0.5 ? 1 : 0); }
return res;
}
// 多项式快速幂
vector<int> poly_pow(vector<int> base, int k, int max_sum) {
vector<int> res = {1};
while (k > 0) {
if ((k & 1) != 0) { res = multiply(res, base, max_sum + 1); }
base = multiply(base, base, max_sum + 1);
k >>= 1;
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, k;
cin >> n >> k;
vector<int> a(n);
int max_a = 0;
for (int i = 0; i < n; i++) {
cin >> a[i];
max_a = max(max_a, a[i]);
}
int max_sum = k * max_a; // 最大可能总价
// 构造初始多项式 F(x)
vector<int> base(max_a + 1, 0);
for (int i = 0; i < n; i++) { base[a[i]] = 1; }
// 快速幂 F(x)^k
vector<int> ans = poly_pow(base, k, max_sum);
// 输出所有可能总价
for (int i = 0; i <= max_sum; i++) {
if (ans[i] != 0) { cout << i << " "; }
}
cout << "\n";
return 0;
}
NTT⚓︎
数论快速傅里叶变换(\text{Number Theoretic Transform},\text{NTT})是一种高效实现 \text{DFT} 的算法。\text{NTT} 算法的基本思想是将 \text{DFT} 的复数域转化为模数域。
\text{NTT} 通过将离散傅立叶变换化为 F={\mathbb {Z}/p},整数模质数 p。这是一个有限域,只要 n 可除 p-1,就存在本原 n 次方根,所以有 p=\xi n+1。
具体来说,对于质数 p=qn+1 (n=2^m),原根 g 满足 g^{qn} \equiv 1 \pmod p, 将 g_n=g^q\pmod p 看做 \omega_n 的等价,则其满足相似的性质,比如 g_n^n \equiv 1 \pmod p, g_n^{n/2} \equiv -1 \pmod p。