【蓝桥杯】进阶-状态压缩动态规划的典型题型深入分析


1. 树上DP概述

树形 DP,即在树上进行的 DP。由于树固有的递归性质,树形 DP 一般都是递归进行的。

大部分的树形 DP 都是线性的,并且由于树本身就是有序的,所以具有十分良好的性质,例如子结构性质等。

树形 DP 在算法竞赛中考察多样,但是简单的可以分为:

  1. 树上线性 DP。
  2. 换根 DP。

树上线性 DP 也分为很多种,不同的题有不同的考法,实际上,所有的 DP 都能在树上考,但是蓝桥比赛中,一般就几种考法:

  1. 树上决策,例如选最大值,最小值。
  2. 树上背包。
  3. 换根,换根 dp 是树上的一类特殊性质。

接下来,我们将通过几个问题来描述这三种问题的解法。

2. 树上决策问题

树上决策问题,往往是子节点向父节点转移时,只取最优的解,这一点与线性 DP 十分相似。

看一道例题:

2. 1 生命之树-真题

图片描述

这题看着挺玄乎,其实并没有那么复杂。

我们观察题目要求:给定一棵树,选出一个非空集合,使得对于任意两个元素 $a, b$,都存在一个序列 $a, v_1, …v_k, b$ 是这个集合里的元素,并且相邻两个点之间有一条边。

本来可以一句话说清楚的事情,但是偏偏要给出数学定义,所以要考察大家的归纳整理能力。

实际上,就是要在树中选出一个连通块即可,并且满足连通块的和值最大。

为什么呢?

我们观察一幅图,相信大家能理解了:

图片描述

绿色的代表我们选择的点集合。这些点是连通的,所以满足要求。

如果换成这个样子:

图片描述

这样就不满足题目要求了。

所以大家可以体会出来,题目的要求,其实就是找一个树上的连通块。

那么我们的问题就变成了在树上找最大的连通块了。

树形 DP,终究还是 DP,所以需要划分子问题。

我们常用的方法是,将子节点为根的子树,看成子问题,然后合并到当前根

将节点从深到浅(子树从小到大)的顺序作为 DP 的阶段,在 DP 的表示中,通常第一维代表节点的编号,后续维度按照问题进行设计。

首先我们需要解决一个问题,树上的连通块是什么?有什么性质可以利用。

答案是:树上的联通块也是树,他一定有根。所以我们要是找到这个根,或者枚举这个根,就可以找到答案。

我们设计的状态如下:

$dp_i$ 表示,对于节点为 $i$ 的子树,我们找到的以 $i$ 为根的连通块和值最大是 $dp_i$。

那么我们的转移的意义就是:对于 $i$ 来说,由于 $i$ 一定存在连通块中,所以,我们要找到他的儿子中,哪些是和 $i$ 连着的。

有一种贪心方案,对于 $i$ 的儿子 $v \in son(i)$ ,如果 $dp_v \ge 0$,我们就将他接入父亲即可。

所以,我们的转移方程就是: $$ dp_i = w_i + \sum _{dp_j \ge 0 & j \in son(i)} dp_j $$ 代码如下:

  • C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <vector>
using namespace std;

const int N = 1e5+100;
typedef long long ll;
vector<int> G[N];
int w[N];
ll dp[N], ans = -1e18;
int n;

void dfs(int u, int f) {
dp[u] = w[u];
for (int v : G[u]) {
if (v == f) continue;
dfs(v, u);
if (dp[v] > 0) {
dp[u] += dp[v];
}
}
ans = max(ans, dp[u]);
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> w[i];
}
int u, v;
for (int i = 1; i < n; ++i) {
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
cout << ans << endl;
return 0;
}
  • Java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import java.util.*;

public class Main {
private static final int N = (int) (1e5 + 100);
private static long[] dp;
private static int[] w;
private static List<List<Integer>> G;
private static long ans = Long.MIN_VALUE;
private static int n;

private static void dfs(int u, int f) {
dp[u] = w[u];
for (int v : G.get(u)) {
if (v == f) continue;
dfs(v, u);
if (dp[v] > 0) {
dp[u] += dp[v];
}
}
ans = Math.max(ans, dp[u]);
}

public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
n = scanner.nextInt();
w = new int[N];
G = new ArrayList<>();
for (int i = 0; i < N; i++) {
G.add(new ArrayList<>());
}
dp = new long[N];

for (int i = 0; i < n; i++) {
w[i] = scanner.nextInt();
}

for (int i = 0; i < n - 1; i++) {
int u = scanner.nextInt() - 1;
int v = scanner.nextInt() - 1;
G.get(u).add(v);
G.get(v).add(u);
}

dfs(0, -1);
System.out.println(ans);
scanner.close();
}
}
  • Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import sys
sys.setrecursionlimit(100000)

n = int(input())
aList = [0] + [int(i) for i in input().split()]
tree = [[] for i in range(n+1)]
ans = 0
dp = [0 for i in range(n+1)]

for i in range(n-1):
m, n = map(int, input().split())
tree[m].append(n)
tree[n].append(m)

def dfs(u, f):
global ans
dp[u] = aList[u]
for i in tree[u]:
if i != f:
dp[i] = dfs(i, u)
if dp[i] > 0:
dp[u] += dp[i]
ans = max(ans, dp[u])
return dp[u]

dfs(1, 0)
print(ans)

3. 树上背包问题

树上背包问题,本质上还是背包,可以看成在树上进行的背包。

每次转移都是在父亲与儿子之间进行了一次经典背包转移。

3.1 小明的背包6

图片描述

这个是典型的依赖背包问题。

并且依赖关系构成了一棵树。

我们看样例:

1
2
3
4
5
6
7
6 15
3 4 0
2 3 1
2 5 1
3 5 1
4 8 2
3 9 2

图片描述

依赖关系如上图所示:上图的含义是如果只有购买了 $1$ 号物品,才能购买 $2, 3, 4$ 号物品。

记住,我们的目标是划分子问题,也就是说,只要保证了一个子问题的划分是正确的,那么由于树的优良递归性质,其他的也会是正确的。

复习一下普通的背包问题,用 $dp_i$ 表示,在使用了 $i$ 空间的情况下的最大价值。

但是在树问题中,由于第一维度是节点的编号,所以我们用 $dp_{i,j}$ 表示对于 $i$ 子树来说,使用了 $j$ 空间的最大价值。

当然题目中有要求,必须满足依赖关系,所以,我们需要重新定义: $dp_{i,j}$ 表示对于 $i$ 子树来说,使用了 $j$ 空间且满足依赖关系的最大价值。

如何满足呢?

我们只需要保证每一个 $dp_{i,j}$ 都选了 $i$ 节点即可。

我们可以在背包中预留出节点 $i$ 的空间即可。

代码如下:

  • C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;

const int N = 1e2+20;
vector<int> G[N];
int n, V;
int v[N], w[N];
int dp[N][N];

void dfs(int u) {
for (int i = v[u]; i <= V; ++i)
dp[u][i] = w[u];
for (int i : G[u]) {
dfs(i);
for (int j = V; j >= v[u] + v[i]; --j) {
for (int k = v[i]; k <= j - v[u]; ++k) // 剩余的空间
dp[u][j] = max(dp[u][j - k] + dp[i][k], dp[u][j]);
}
}
}

int main() {
cin >> n >> V;
int s;
for (int i = 1; i <= n; ++i) {
cin >> v[i] >> w[i] >> s;
G[s].push_back(i);
}
dfs(0);
cout << dp[0][V] << '\n';
}
  • Java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import java.util.*;

public class Main {
private static int V;
private static int[][] dp;
private static List<List<Integer>> G;
private static int[] v;
private static int[] w;

private static void dfs(int u) {
for (int i = v[u]; i <= V; ++i) {
dp[u][i] = w[u];
}
for (int child : G.get(u)) {
dfs(child);
for (int j = V; j >= v[u] + v[child]; --j) {
for (int k = v[child]; k <= j - v[u]; ++k) {
dp[u][j] = Math.max(dp[u][j - k] + dp[child][k], dp[u][j]);
}
}
}
}

public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
V = scanner.nextInt();
G = new ArrayList<>();
for (int i = 0; i <= n; ++i) {
G.add(new ArrayList<>());
}
v = new int[n + 1];
w = new int[n + 1];
dp = new int[n + 1][V + 1];
for (int i = 1; i <= n; ++i) {
v[i] = scanner.nextInt();
w[i] = scanner.nextInt();
int s = scanner.nextInt();
G.get(s).add(i);
}
dfs(0);
System.out.println(dp[0][V]);
scanner.close();
}
}
  • Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:
def dfs(self, u, dp, G, v, w, V):
for i in range(v[u], V + 1):
dp[u][i] = w[u]
for child in G[u]:
self.dfs(child, dp, G, v, w, V)
for j in range(V, v[u] + v[child] - 1, -1):
for k in range(v[child], j - v[u] + 1):
dp[u][j] = max(dp[u][j - k] + dp[child][k], dp[u][j])

def main(self):
n, V = map(int, input().split())
G = [[] for _ in range(n + 1)] # 0-indexed in Python
v = [0] * (n + 1)
w = [0] * (n + 1)
for i in range(1, n + 1):
v[i], w[i], s = map(int, input().split())
G[s].append(i)
dp = [[0] * (V + 1) for _ in range(n + 1)]
self.dfs(0, dp, G, v, w, V)
print(dp[0][V])

# Run the main function
solution = Solution()
solution.main()

4. 换根 DP 问题

换根 DP,面对的问题通常是“不定根”问题,也就是说,对于一棵树,他的根不一定是 $1$ 号点,可能是任意某个点。

或者在某些问题中,我们需要尝试计算以每个点为根的情况,最后维护出最大值。

我们先看一副图,来理解所谓的“换根”。

图片描述

我们将原来以 $1$ 为根换成了以 $2$ 为根。那么树的形态也就发生了变化。

如果每次都是选择一个点作为根进行处理,那么总的时间复杂度为 $O(n^2)$,但是如果我们能发现性质,我们可以将复杂度降为 $O(n)$。

即换一次根的复杂度为 $O(1)$,下面,我们将讲述这种方法。

在一般的问题中,我们常常是利用dfs来不断的将根转换为根的子节点。

我们会发现一些事情:

图片描述

我们一次转换的过程,其实有很大一部分并没有发生变化,体现在 DP 转移中,就是这些点的 DP 值也不会发生改变。

实际上改变的只有改变身份的两个点,其他的点都不会发生变化。

在换根的问题中,一般的步骤如下:

  1. 以 $1$ 为根进行一遍扫描,并且处理出必要的信息,例如深度、DP 值等。
  2. 开始以 $1$ 进行换根,并且向下递归,在递归之前,需要将自己变成子节点的身份。
  3. 进入新的根后,按照根的身份,重新进行转移。并且维护答案。

4.1 卖树

图片描述

本题需要计算以每个点为根的情况下,产生的盈利。

如果我们确定了一个点为根,我们很容易算出答案,如果确定了根,问题就变成了求最大深度,这个问题只需要一遍DFS就可以完成。

1
2
3
4
5
6
7
8
9
void dfs(int u, int f, int dt) { // 求出以1为根的原始信息
dep[u] = dt;
Mdp[u] = 0; // Mdp即为当前点为根的最大深度
for (int v : G[u]) {
if (v == f) continue;
dfs(v, u, dt + 1);
Mdp[u] = max(Mdp[v] + 1, Mdp[u]);
}
}

因为节点数量太多,我们无法承受 $O(n^2)$ 的复杂度,所以我们需要进行换根,

基本思想如上述一致:

  1. 我们需要先算出以 $1$ 为根的信息,包括以每个节点为子树的最大深度,从 $1$ 转移到 $i$ 节点的代价。
  2. 我们从 $1$ 号点开始换根,每次只将根的身份换给儿子,然后进入递归,进入之前,我们需要将当前点的身份改为子节点。
  3. 进行新的根,由于原来的转移已经失效,所以需要重新转移。并且维护答案,然后重复2步骤。
  • C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <iostream>
#include <vector>
using namespace std;

const int N = 1e5+10;
vector<int> G[N];
int n, k, c;
int dep[N], Mdp[N];
typedef long long ll;
ll ans = 0;

void dfs(int u, int f, int dt) { // 求出以1为根的原始信息
dep[u] = dt;
Mdp[u] = 0;
for (int v : G[u]) {
if (v == f) continue;
dfs(v, u, dt + 1);
Mdp[u] = max(Mdp[v] + 1, Mdp[u]);
}
}

void dfs2(int u, int f) { // 开始换根
int tmpf = 0, Mx1 = 0, Mx2 = 0;
for (int v : G[u]) {
tmpf = max(tmpf, Mdp[v] + 1);
}
ans = max(1ll * tmpf * k - 1ll * dep[u] * c, ans);
int pre = Mdp[u];
for (int v : G[u]) {
if (Mdp[v] + 1 > Mx1) {
Mx2 = Mx1;
Mx1 = Mdp[v] + 1;
} else if (Mdp[v] + 1 > Mx2) {
Mx2 = Mdp[v] + 1;
}
}
for (int v : G[u]) {
if (v == f) continue;
if (Mdp[v] + 1 == Mx1) Mdp[u] = Mx2;
else Mdp[u] = Mx1;
dfs2(v, u);
}
Mdp[u] = pre;
}

void sol() {
for (int i = 1; i <= n; ++i) G[i].clear();
ans = 0;
cin >> n >> k >> c;
int u, v;
for (int i = 1; i < n; ++i) {
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0, 0);
dfs2(1, 0);
cout << ans << '\n';
}

int main() {
ios::sync_with_stdio(0);
int T;
cin >> T;
while (T--) {
sol();
}
return 0;
}
  • Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from collections import defaultdict
import sys
sys.setrecursionlimit(100000)

N = 100010
G = defaultdict(list)
n, k, c = 0, 0, 0
dep = [0] * N
Mdp = [0] * N
ans = 0

def dfs(u, f, dt): # 求出以1为根的原始信息
global dep, Mdp
dep[u] = dt
Mdp[u] = 0
for v in G[u]:
if v == f:
continue
dfs(v, u, dt + 1)
Mdp[u] = max(Mdp[v] + 1, Mdp[u])

def dfs2(u, f): # 开始换根
global ans, dep, Mdp
tmpf = 0
Mx1 = 0
Mx2 = 0
for v in G[u]:
tmpf = max(tmpf, Mdp[v] + 1)
ans = max(ans, tmpf * k - dep[u] * c)
pre = Mdp[u]
for v in G[u]:
if Mdp[v] + 1 > Mx1:
Mx2 = Mx1
Mx1 = Mdp[v] + 1
elif Mdp[v] + 1 > Mx2:
Mx2 = Mdp[v] + 1
for v in G[u]:
if v == f:
continue
if Mdp[v] + 1 == Mx1:
Mdp[u] = Mx2
else:
Mdp[u] = Mx1
dfs2(v, u)
Mdp[u] = pre

def sol():
global n, k, c, ans, G, dep, Mdp
n, k, c = map(int, input().split())
G.clear()
ans = 0
for _ in range(n - 1):
u, v = map(int, input().split())
G[u].append(v)
G[v].append(u)
dfs(1, 0, 0)
dfs2(1, 0)
print(ans)

T = int(input())
for _ in range(T):
sol()
  • Java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import java.util.*;
import java.io.*;

public class Main {
static final int N = 100010;
static List<Integer>[] G;
static int n, k, c;
static int[] dep, Mdp;
static long ans;

static void dfs(int u, int f, int dt) { // 求出以1为根的原始信息
Mdp[u] = 0;
for (int v : G[u]) {
if (v == f) continue;
dfs(v, u, dt + 1);
Mdp[u] = Math.max(Mdp[v] + 1, Mdp[u]);
}
}

static void dfs2(int u, int f) { // 开始换根
int tmpf = 0, Mx1 = 0, Mx2 = 0;
for (int v : G[u]) {
tmpf = Math.max(tmpf, Mdp[v] + 1);
}
ans = Math.max(ans, (long) tmpf * k - (long) dep[u] * c);
int pre = Mdp[u];
for (int v : G[u]) {
if (Mdp[v] + 1 > Mx1) {
Mx2 = Mx1;
Mx1 = Mdp[v] + 1;
} else if (Mdp[v] + 1 > Mx2) {
Mx2 = Mdp[v] + 1;
}
}
for (int v : G[u]) {
if (v == f) continue;
if (Mdp[v] + 1 == Mx1) {
Mdp[u] = Mx2;
} else {
Mdp[u] = Mx1;
}
dfs2(v, u);
}
Mdp[u] = pre;
}

static void sol(Scanner scanner) {
for (int i = 1; i <= n; i++) G[i].clear();
ans = 0;
n = scanner.nextInt();
k = scanner.nextInt();
c = scanner.nextInt();
int u, v;
for (int i = 1; i < n; i++) {
u = scanner.nextInt();
v = scanner.nextInt();
G[u].add(v);
G[v].add(u);
}
dfs(1, 0, 0);
dfs2(1, 0);
System.out.println(ans);
}

public static void main(String[] args) {
G = new ArrayList[N];
for (int i = 0; i < N; i++) {
G[i] = new ArrayList<>();
}
dep = new int[N];
Mdp = new int[N];
Scanner scanner = new Scanner(System.in);
int T = scanner.nextInt();
while (T-- > 0) {
sol(scanner);
}
}
}

5. 作业

题目 链接
取气球(算法赛) https://www.lanqiao.cn/problems/17024/learning/
左孩子右兄弟(21 年省赛) https://www.lanqiao.cn/problems/1451/learning/