数位动态规划⚓︎
数位动态规划(\text{Digit DP})用于解决与数字的位数、数位相关的问题。其核心思想是通过对数字的每一位进行状态定义和转移,从而求解整个数字的问题。
数位是指把一个数字按照个、十、百、千等等一位一位地拆开,关注它每一位上的数字。如果拆的是十进制数,那么每一位数字都是 [0,9],其他进制可类比十进制。
数位 \text{DP} 问题一般具有这几个特征:
- 要求统计满足一定条件的数的数量,即最终目的为计数
- 这些条件经过转化后可以使用数位的思想去理解和判断
- 输入会提供一个数字区间(有时也只提供上界)来作为统计的限制
- 上界很大(比如 10^{18}),暴力枚举验证会超时
一般解决这类问题的方法是通过记忆化搜索,即通过递归的方式枚举每一位数字,同时记录当前枚举的位数、是否受限制、是否已经填了数字等状态。这样可以通过备忘录记录已经计算过的状态,避免重复计算。
- 备忘录 memo:记录已经计算过的状态,避免重复计算。初始化为-1,表示没有计算过
- DFS 函数:递归枚举每一位数字,同时记录当前枚举的位数、是否受限制、是否已经填了数字等状态。通过备忘录记录已经计算过的状态,避免重复计算。各个参数的含义:
- i:当前枚举的第几位。如果 i 等于数字的位数,表示已经枚举完了
- is\_limit:当前枚举数是否受限制。如果受限制表示当前枚举的数字不能超过上界。比如,如果上界是 1234,那么在枚举第一位的时候,只能枚举 [0,1],不能枚举 [2,9];如果第一位枚举了 1,那么第二位只能枚举 [0,2],不能枚举 [3,9]
- is\_num:当前是否已经填了数字,受前导 0 限制。如果前导 0 有影响,需要去掉该参数,或第一次调用时设置为 true
- 添加其他参数:如 pre 表示前一位填的数字,mask 表示已经使用的数字,diff 表示奇偶位数差异(可能为负数需要加上偏移量),mod 表示余数
备忘录 memo 的维度与可变参数个数有关,is\_limit 和 is\_num 无需记忆化,因为使用 memo 时这两个参数必定为 false 和 true。
在问题求解时,一般将 [l, r] 转换为 [0, r] 和 [0, l-1],然后分别求解,最后相减即可。
如果 l-1 也不适合(如 l=0 或非常大的数字),则直接求解 [0, l],然后相减并且对 l 进行特殊的 check 处理,判断是否满足条件。
模版(伪代码)
Text Only
digit_dp(s):
m = s.length()
memo = 视题目状态而定
dfs(i, is_limit, is_num, 其他状态):
if i == m:
return is_num ? 1 : 0
if !is_limit && is_num && 当前状态已被记忆化:
return memo[当前状态]
res = 0
if !is_num:
res += dfs(i + 1, false, false, 更新后的其他状态)
up = is_limit ? s[i] - '0' : 9
for d in [1 - is_num, up]:
if 满足题目限制:
res += dfs(i + 1, is_limit && d == up, true, 更新后的其他状态)
if !is_limit && is_num:
memo[当前状态] = res
return res
return dfs(0, true, false, 初始状态)
solve(low, high):
return digit_dp(high) - digit_dp(low - 1)
当 low = 0 或题目使用大整数区间时,不能直接套 low - 1,需要单独处理左端点。
数字计数
给定两个正整数 a 和 b,求在 [a,b] 中的所有整数中,每个数码(digit)各出现了多少次。
C++
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
int64_t digit_dp(const string &s, int num) {
int m = s.length();
vector<vector<int64_t>> memo(m, vector<int64_t>(m, -1));
// count: 当前已经填了多少个 num
using dfs_type = function<int64_t(int64_t, bool, bool, int)>;
dfs_type dfs = [&](int i, bool is_limit, bool is_num, int count) -> int64_t {
if (i == m) { return is_num ? count : 0; }
if (!is_limit && is_num && memo[i][count] != -1) { return memo[i][count]; }
int64_t res = 0;
// 可以跳过当前数位
if (!is_num) { res = dfs(i + 1, false, false, count); }
int up = is_limit ? s[i] - '0' : 9;
// 枚举要填入的数字 d
for (int d = 1 - is_num; d <= up; ++d) {
res = res + dfs(i + 1, is_limit && d == up, true, count + (d == num));
}
// 记忆化结果
if (!is_limit && is_num) { memo[i][count] = res; }
return res;
};
return dfs(0, true, false, 0);
}
int main() {
string low;
string high;
cin >> low >> high;
vector<int64_t> count(10);
for (char ch : low) { count[ch - '0']++; }
for (int i = 0; i <= 9; ++i) { cout << digit_dp(high, i) - digit_dp(low, i) + count[i] << " "; }
}
MYQ10 - Mirror Number
求 [a,b] 中镜像回文的个数。镜像回文是指上下对称,左右对称的数字。显然:镜像回文由 0,1,8 构成。
C++
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
using VVVI = vector<vector<vector<int64_t>>>;
VVVI memo; // 记忆化数组
int64_t digit_dp(string &s) {
int64_t m = s.length();
s.push_back('0');
reverse(s.begin(), s.end());
vector<int64_t> fill(m + 1, 0);
// start表示非0的起始位置,那么从这里开始的长度是m-start,最多只用遍历一半
using dfs_type = function<int64_t(int64_t, bool, bool, int64_t)>;
dfs_type dfs = [&](int64_t i, bool is_limit, bool ok, int64_t start) -> int64_t {
if (i == 0) { return ok; } // 枚举完所有数位, ok表示是否对称
if (!is_limit && memo[i][start][ok] != -1) { return memo[i][start][ok]; }
int64_t res = 0;
int64_t up = is_limit ? s[i] - '0' : 9;
// 枚举要填入的数字 d
for (int64_t d = 0; d <= up; ++d) {
if (d != 0 && d != 1 && d != 8) { // 不能有非0,1,8
continue;
}
fill[i] = d;
int64_t new_ok = (ok && i <= start / 2) ? d == fill[start - i + 1] : ok; // 对称性
int64_t new_start = (start == i && d == 0) ? start - 1 : start; // 更新非0起始位置
res += dfs(i - 1, is_limit && d == up, new_ok, new_start);
}
if (!is_limit) { memo[i][start][ok] = res; }
return res;
};
return dfs(m, true, true, m);
}
int64_t check(const string &s) {
int64_t m = s.length();
int64_t res = 0;
for (int64_t i = 0; i < m; i++) {
if (s[i] != '0' && s[i] != '1' && s[i] != '8') { // 不能有非0,1,8
return 0;
}
if (s[i] != s[m - i - 1]) { // 对称
return 0;
}
}
return 1;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t t;
cin >> t;
size_t m = 0;
vector<string> low(t), high(t);
for (int64_t i = 0; i < t; i++) {
cin >> low[i] >> high[i];
m = max({m, low[i].length(), high[i].length()});
}
// 多个测试用例时,复用memo数组, 避免memo重新计算
memo.resize(m + 1, vector<vector<int64_t>>(m + 1, vector<int64_t>(2, -1)));
for (int i = 0; i < t; i++) {
int is_low = check(low[i]); // low本身是否符合要求
int64_t high_ans = digit_dp(high[i]);
int64_t low_ans = digit_dp(low[i]);
cout << high_ans - low_ans + is_low << "\n";
}
}