跳转至

快速傅里叶变换⚓︎

FFT⚓︎

快速傅立叶变换(\text{Fast Fourier Transform,FFT})是一种高效实现离散傅里叶变换(\text{DFT})的算法。

\text{FFT} 算法的基本思想是分治。就 \text{DFT} 来说,它分治地来求当 x = \omega_n^k 的时候 f(x) 的值。

\text{FFT} 算法的分治思想体现在将多项式分为奇次项和偶次项处理,分别用奇偶次项系数建立新的函数:f(x) = Even(x^2) + x * Odd(x^2)

其中Even(x) = \sum a_{2i}x^iOdd(x) = \sum a_{2i+1}x^i

利用偶数次单位根的性质 \omega^i_n = -\omega^{i + n/2}_n,并注意 Even(x^2)Odd(x^2) 只依赖于 x^2,因此在复平面上 \omega^i_n\omega^{i+n/2}_n 对应的 Even(x^2)Odd(x^2) 的值相同。以下推导中,令 n 为变换长度,k 为当前求值的下标。

\begin{aligned} f(\omega_n^k) &= Even((\omega_n^k)^2) + \omega_n^k \times Odd((\omega_n^k)^2) \\ &= Even(\omega_n^{2k}) + \omega_n^k \times Odd(\omega_n^{2k}) \\ &= Even(\omega_{n/2}^k) + \omega_n^k \times Odd(\omega_{n/2}^k) \\ f(\omega_n^{k+n/2}) &= Even(\omega_n^{2k+n}) + \omega_n^{k+n/2} \times Odd(\omega_n^{2k+n}) \\ &= Even(\omega_n^{2k}) - \omega_n^k \times Odd(\omega_n^{2k}) \\ &= Even(\omega_{n/2}^k) - \omega_n^k \times Odd(\omega_{n/2}^k) \end{aligned}

因此求出了 Even(\omega_{n/2}^k)Odd(\omega_{n/2}^k) 后,就可以同时求出 f(\omega_n^k)f(\omega_n^{k+n/2})

分治 \text{DFT} 能处理的多项式长度只能是 2^m(m \in \mathbf{N}^ \ast ),否则在分治的时候左右不一样长,右边取不到系数。所以要在第一次 \text{DFT} 之前就把序列向上补成长度为 2^m(高次系数补 0)、最高项次数为 2^m-1 的多项式。

在代入值的时候,因为要代入n个不同值:\omega_n^0,\omega_n^1,\omega_n^2,\cdots, \omega_n^{n-1} (n=2^m) 一共 2^m 个不同值。

注意:

  1. 在调用 \text{FFT} 前,需要将多项式的系数转化为复数,同时要将系数调整为 2^m 长度的多项式。
  2. 在求快速傅里叶变换的逆变换时,注意角度要乘以 -1,这在代码中通过 inv 参数来控制。
  3. 求出快速傅里叶变换的逆变换后,需要正则化,将结果除以 n
  4. 多项式乘法h(x) = f(x) \times g(x),可以通过 h(x) = \text{FFT}^{-1}(\text{FFT}(f(x)) \times \text{FFT}(g(x))) 来实现。在最后的结果中,需要将系数四舍五入取整。

w_k 的计算方式

\text{FFT}:

w_k = e^{\frac{2\pi}{n} k i} = e^{\theta k i} = cos(\theta k) + i * sin(\theta k)

\text{Inverse FFT}:

w_k = e^{- \frac{2\pi}{n} k i} =e^{- \theta k i} = cos(\theta k) - i * sin(\theta k)

递推关系:

w_{k+1} = e^{\theta (k+1) * i} = e^{\theta k * i} * e^{\theta * i} = w(k) * w
【模板】多项式乘法(FFT)

给定两个多项式 f(x)g(x),求它们的乘积多项式 h(x) = f(x) \times g(x)

Keypoints

一般来说,\text{FFT} 的实现分为两种:递归实现和优化实现。递归实现的代码较为简单,优化实现则在此基础上进行了位逆序置换的优化。

C++
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <vector>
using namespace std;

void fft(vector<complex<double>> &f, int inv) {
  int n = f.size();  // n must be a power of 2
  if (n == 1) { return; }
  vector<complex<double>> even(n / 2), odd(n / 2);
  for (int i = 0; i < n / 2; i++) {
    even[i] = f[i * 2];
    odd[i]  = f[i * 2 + 1];
  }
  fft(even, inv);
  fft(odd, inv);
  double theta = 2 * numbers::pi / n;
  complex<double> wk(1), w(cos(theta), inv * sin(theta));
  for (int k = 0; k < n / 2; k++) {
    f[k]          = even[k] + wk * odd[k];
    f[k + n / 2]  = even[k] - wk * odd[k];
    wk           *= w;
  }
}

int main() {
  int n, m;
  cin >> n >> m;                            // n, m 为f(x)和g(x)的最高次
  int len = pow(2, ceil(log2(n + m + 1)));  // 调整结果的长度为2的幂次
  vector<complex<double>> f(len), g(len), h(len);
  for (int i = 0; i <= n; i++) { cin >> f[i]; }
  for (int i = 0; i <= m; i++) { cin >> g[i]; }
  fft(f, 1);
  fft(g, 1);
  for (int i = 0; i < len; i++) { h[i] = f[i] * g[i]; }  // 计算每个点的乘积
  fft(h, -1);                                            // 求h(x)的逆FFT
  // 逆变换正则化, 四舍五入取整
  for (int i = 0; i < n + m + 1; i++) { cout << round(h[i].real() / len) << " "; }
  return 0;
}

位逆序置换⚓︎

在 FFT 中,位逆序置换是一个重要的步骤。考虑递归过程,在每次递归会两两分组。

8项多项式为例,模拟拆分的过程:

  1. 初始序列为 \{x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7\}
  2. 一次二分之后 \{x_0, x_2, x_4, x_6\},\{x_1, x_3, x_5, x_7 \}
  3. 两次二分之后 \{x_0,x_4\} \{x_2, x_6\},\{x_1, x_5\},\{x_3, x_7 \}
  4. 三次二分之后 \{x_0\}\{x_4\}\{x_2\}\{x_6\}\{x_1\}\{x_5\}\{x_3\}\{x_7 \}

规律:其实就是原来的序列每个数用二进制表示,然后把二进制翻转,就是最终位置的下标。

\begin{aligned} 0 &\to 000 \to 000 \to 0 \\ 1 &\to 001 \to 100 \to 4 \\ 2 &\to 010 \to 010 \to 2 \\ 3 &\to 011 \to 110 \to 6 \\ 4 &\to 100 \to 001 \to 1 \\ 5 &\to 101 \to 101 \to 5 \\ 6 &\to 110 \to 011 \to 3 \\ 7 &\to 111 \to 111 \to 7 \\ \end{aligned}

O(n\log n) 的时间复杂度内,可以将序列进行位逆序置换。

位逆序置换代码实现

位逆序置换优化

cpp中也可以直接使用bitset来翻转二进制位,时间复杂度与位数有关。

该实现的时间复杂度为 O(n\log n)

C++
void bit_reverse(vector<complex<double>> &f) {
  int n = f.size();
  // i = 0...01, j = 10...0, 在二进制下是翻转的
  for (int i = 1, j = n / 2; i < n - 1; i++) {
    if (i < j) { swap(f[i], f[j]); }  // 互换位置,i < j 避免重复交换
    int k = n / 2;  // 保持二进制翻转
    while (j >= k) {
      j -= k;
      k /= 2;
    }
    if (j < k) { j += k; }
  }
}

该实现的时间复杂度为 O(n)

C++
void bit_reverse(vector<complex<double>> &f) {
  int n = f.size();
  vector<int> reverse(n);
  for (int i = 0; i < n; ++i) {
    reverse[i] = reverse[i >> 1] >> 1;
    if ((i & 1) != 0) { reverse[i] |= n >> 1; }  // 如果最后一位是 1,则翻转成 n/2
  }
  for (int i = 0; i < n; ++i) {  // 保证每对数只翻转一次
    if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
  }
}

蝶形操作⚓︎

\text{FFT} 中,每次递归都会两两分组,然后进行蝶形操作。

蝶形操作是一种两两合并的操作,每次合并两个数,然后乘上一个单位根。在蝶形操作中,每个数都会和另一个数进行一次乘法和一次加法。

【模板】多项式乘法(FFT)
C++
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <utility>
#include <vector>
using namespace std;

void bit_reverse(vector<complex<double>> &f) {
  int n = f.size();
  vector<int> reverse(n);
  for (int i = 0; i < n; ++i) {
    reverse[i] = reverse[i >> 1] >> 1;
    if ((i & 1) != 0) { reverse[i] |= n >> 1; }  // 如果最后一位是 1,则翻转成 n/2
  }
  for (int i = 0; i < n; ++i) {  // 保证每对数只翻转一次
    if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
  }
}

void fft(vector<complex<double>> &f, int inv) {
  int n = f.size();
  bit_reverse(f);
  for (int len = 2; len <= n; len *= 2) {  // 递归对应的长度
    double theta = 2 * numbers::pi / len * inv;
    complex<double> w(cos(theta), sin(theta));  // 当前单位根
    for (int i = 0; i < n; i += len) {
      complex<double> wi(1);
      for (int j = 0; j < len / 2; j++) {
        complex<double> even = f[i + j], odd = wi * f[i + j + len / 2];
        f[i + j]            = even + odd;
        f[i + j + len / 2]  = even - odd;
        wi                 *= w;
      }
    }
  }
  if (inv == -1) {  // 逆变换正则化
    for (int i = 0; i < n; i++) { f[i] /= n; }
  }
}

int main() {
  int n, m;
  cin >> n >> m;
  int len = pow(2, ceil(log2(n + m + 1)));
  vector<complex<double>> f(len), g(len), h(len);
  for (int i = 0; i <= n; i++) { cin >> f[i]; }
  for (int i = 0; i <= m; i++) { cin >> g[i]; }
  fft(f, 1);
  fft(g, 1);
  for (int i = 0; i < len; i++) { h[i] = f[i] * g[i]; }
  fft(h, -1);
  for (int i = 0; i < n + m + 1; i++) { cout << round(h[i].real()) << " "; }
  return 0;
}

Tip

并行扫描算法的思路与蝶形变换类似

应用⚓︎

  1. 多项式乘法 h(x) = f(x) \times g(x)
  2. 多项式求逆 g(x) = f(x)^{-1}
    f(x) \times g(x) = 1g(x) = f(x)^{-1}g(x)f(x) 的逆元
Thief in a Shop

n 种商品,每种价格为 a_1, a_2, \dots, a_n,每种商品数量无限。恰好带走 k 件商品(同一种可以多件),问可能的总价格有哪些。

Hint

定义多项式生成函数:

F(x) = x^{a_1} + x^{a_2} + \dots + x^{a_n}
  • 每次选一件商品,对应取一次 F(x)
  • k 件商品 → 求 k 次幂:F(x)^k
  • x^s 的系数表示总价 s 是否可达,所以答案就是 F(x)^k 中所有系数不为 0 的项的指数

F(x)^k 可以通过 \text{FFT} 快速卷积计算。由于系数只关心是否为 0,所以在每次卷积后将系数大于 0 的项置为 1 即可。

另外 k 件物品最大价格为 k \times \max(a_i),所以结果的最高次幂不会超过这个值。可以每次卷积后将多项式截断到这个长度以节省时间和空间。

C++
#include <algorithm>
#include <cmath>
#include <complex>
#include <iostream>
#include <numbers>
#include <utility>
#include <vector>
using namespace std;

using cd = complex<double>;

void bit_reverse(vector<complex<double>> &f) {
  int n = f.size();
  vector<int> reverse(n);
  for (int i = 0; i < n; ++i) {
    reverse[i] = reverse[i >> 1] >> 1;
    if ((i & 1) != 0) { reverse[i] |= n >> 1; }  // 如果最后一位是 1,则翻转成 n/2
  }
  for (int i = 0; i < n; ++i) {  // 保证每对数只翻转一次
    if (i < reverse[i]) { swap(f[i], f[reverse[i]]); }
  }
}

void fft(vector<complex<double>> &f, int inv) {
  int n = f.size();
  bit_reverse(f);
  for (int len = 2; len <= n; len *= 2) {  // 递归对应的长度
    double theta = 2 * numbers::pi / len * inv;
    complex<double> w(cos(theta), sin(theta));  // 当前单位根
    for (int i = 0; i < n; i += len) {
      complex<double> wi(1);
      for (int j = 0; j < len / 2; j++) {
        complex<double> even = f[i + j], odd = wi * f[i + j + len / 2];
        f[i + j]            = even + odd;
        f[i + j + len / 2]  = even - odd;
        wi                 *= w;
      }
    }
  }
  if (inv == -1) {  // 逆变换正则化
    for (int i = 0; i < n; i++) { f[i] /= n; }
  }
}

// 布尔卷积(动态截断到 max_len)
vector<int> multiply(const vector<int> &a, const vector<int> &b, int max_len) {
  int n = 1;
  while (n < a.size() + b.size()) { n <<= 1; }  // 找到大于等于两多项式和的最小 2 的幂
  vector<cd> fa(n), fb(n);
  for (int i = 0; i < a.size(); i++) { fa[i] = a[i]; }
  for (int i = 0; i < b.size(); i++) { fb[i] = b[i]; }
  fft(fa, 1);
  fft(fb, 1);
  for (int i = 0; i < n; i++) { fa[i] *= fb[i]; }  // 点乘
  fft(fa, -1);                                     // 逆变换
  vector<int> res(min(n, max_len), 0);             // 截断到 max_len
  // 只关心总价是否可达, 大于 0 则视为可达
  for (int i = 0; i < res.size(); i++) { res[i] = (fa[i].real() > 0.5 ? 1 : 0); }
  return res;
}

// 多项式快速幂
vector<int> poly_pow(vector<int> base, int k, int max_sum) {
  vector<int> res = {1};
  while (k > 0) {
    if ((k & 1) != 0) { res = multiply(res, base, max_sum + 1); }
    base   = multiply(base, base, max_sum + 1);
    k    >>= 1;
  }
  return res;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);

  int n, k;
  cin >> n >> k;
  vector<int> a(n);
  int max_a = 0;
  for (int i = 0; i < n; i++) {
    cin >> a[i];
    max_a = max(max_a, a[i]);
  }

  int max_sum = k * max_a;  // 最大可能总价

  // 构造初始多项式 F(x)
  vector<int> base(max_a + 1, 0);
  for (int i = 0; i < n; i++) { base[a[i]] = 1; }

  // 快速幂 F(x)^k
  vector<int> ans = poly_pow(base, k, max_sum);

  // 输出所有可能总价
  for (int i = 0; i <= max_sum; i++) {
    if (ans[i] != 0) { cout << i << " "; }
  }
  cout << "\n";

  return 0;
}

NTT⚓︎

数论快速傅里叶变换(\text{Number Theoretic Transform}\text{NTT})是一种高效实现 \text{DFT} 的算法。\text{NTT} 算法的基本思想是将 \text{DFT} 的复数域转化为模数域。

\text{NTT} 通过将离散傅立叶变换化为 F={\mathbb {Z}/p},整数模质数 p。这是一个有限域,只要 n 可除 p-1,就存在本原 n 次方根,所以有 p=\xi n+1

具体来说,对于质数 p=qn+1 (n=2^m),原根 g 满足 g^{qn} \equiv 1 \pmod p, 将 g_n=g^q\pmod p 看做 \omega_n 的等价,则其满足相似的性质,比如 g_n^n \equiv 1 \pmod p, g_n^{n/2} \equiv -1 \pmod p

评论