因为树形DP也算是dp里面的一种类型,虽然感觉没有什么总结的必要,但是碰到一些经典的套路还是想要记录一下,所以发在这个博客里
树形dp可以分为两种类型,一个是选择节点类,一个是树上背包类,大致如下
1. 牛客小白月赛55F
题目描述:
给出一个拓扑序关系的树型图,问图上的节点有多少种排列可能
思路:
首先考虑下面这个问题:有a个红气球,b个绿气球,c个蓝气球有多少种排列组合,这就是一个多重全排列问题 ans=(a+b+c)!/a!/b!/c! 即为所有气球的全排列除以各个相同颜色的气球的全排列,即各个子集和的全排列除以各个子集和的限制条件
那么对于这一题,用这个思想,他们满足以下关系
代码:
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=2e5+5;
const int inf=1e9+7;
const int mod=1e9+7;
vector<int>vec[maxn];
int fac[maxn],inv[maxn];
void init(int n=1e5+5){
fac[0]=1;
fac[1]=1;
for(int i=2;i<=n;i++){
fac[i]=fac[i-1]*i%mod;
}
inv[1]=1;
for(int i=2;i<=n;i++){
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=2;i<=n;i++){
inv[i]=inv[i]*inv[i-1]%mod;
}
}
int dp[maxn];
int siz[maxn];
void dfs1(int u,int fa){
siz[u]=1;
for(int v:vec[u]){
if(v!=fa){
dfs1(v,u);
siz[u]+=siz[v];
}
}
}
void dfs(int u,int fa){
dp[u]=fac[siz[u]-1];
for(int v:vec[u]){
if(v!=fa){
dfs(v,u);
dp[u]=dp[u]*dp[v]%mod;
dp[u]=dp[u]*inv[siz[v]]%mod;
}
}
}
int dpp[maxn];
int sizz[maxn];
void solve(){
int n;
cin>>n;
int zong=0;
for(int i=1;i<=n;i++){
int m;
cin>>m;
for(int j=0;j<=m;j++){
vec[j].clear();
}
for(int j=2;j<=m;j++){
int u,v;
cin>>u;
v=j;
vec[u].push_back(v);
vec[v].push_back(u);
}
dfs1(1,0);
dfs(1,0);
dpp[i]=dp[1];
sizz[i]=siz[1];
zong+=sizz[i];
}
int ans=fac[zong];
for(int i=1;i<=n;i++){
ans=ans*dpp[i]%mod;
ans=ans*inv[sizz[i]]%mod;
}
cout<<ans<<"\n";
}
signed main(){
init();
int t=1;
//cin>>t;
while(t--){
solve();
}
}
2.Codeforces Round #686 (Div. 3) E
这题是一个基环树,即n个节点n条边,那么一定会构成一个环,把这个环上面的各个点当作树的根节点,那么就是一个环上面有若干颗树
一个经典的套路:树里面任意两点都有一条简单路径,所以假设一颗树有n个节点,则有n*(n-1)/2条简单路径
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
set<int> G[N];
ll vis[N];
int main() {
ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
int T; cin >> T;
while(T--) {
int n; cin >> n;
for(int i = 1; i <= n; i++) G[i].clear(), vis[i] = 1;
for(int i = 1; i <= n; i++) {
int u, v; cin >> u >> v;
G[u].insert(v);
G[v].insert(u);
}
queue<int> q;
for(int i = 1; i <= n; i++) {
if(G[i].size() == 1) {
q.push(i);
vis[i] = 1;
}
}
while(!q.empty()) {
int u = q.front(); q.pop();
for(auto &v : G[u]) {
vis[v] += vis[u];
vis[u] = 0;
G[v].erase(u);
if(G[v].size() == 1) {
q.push(v);
}
}
}
ll ans = 0;
for(int i = 1; i <= n; i++) {
ans += (vis[i] - 1) * vis[i] / 2;
ans += vis[i] * (n - vis[i]);
}
cout << ans << '\n';
}
return 0;
}
树上背包板子题
此题的多叉转二叉优化代码
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1005;
const int inf=1e9+7;
const int mod=1e9+7;
int n,m;
vector<int>vec[maxn];
int val[maxn];
int dp[maxn][maxn]; //代表以i为根节点的子树中选择j个节点的最大价值
void dfs(int u,int t){
if(t<=0){
return;
}
for(int v:vec[u]){
for(int k=0;k<t;k++){
dp[v][k]=dp[u][k]+val[v];
}
dfs(v,t-1);
for(int k=1;k<=t;k++){
dp[u][k]=max(dp[u][k],dp[v][k-1]);
}
}
}
void solve(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int fa;
cin>>fa;
cin>>val[i];
vec[fa].push_back(i);
}
dfs(0,m);
cout<<dp[0][m]<<"\n";
}
signed main(){
int t=1;
//cin>>t;
while(t--){
solve();
}
}
此题的上下界优化代码
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1005;
const int inf=1e9+7;
const int mod=1e9+7;
int n,m;
vector<int>vec[maxn];
int val[maxn];
int dp[maxn][maxn]; //代表以i为根节点的子树中选择j个节点的最大价值
int siz[maxn];
void dfs(int u){
siz[u]=1;
dp[u][1]=val[u];
for(int v:vec[u]){
dfs(v);
for (int j=min(m+1,siz[u]+siz[v]);j>=1;--j)
{
for (int k=max(1ll,j-siz[u]);k<=siz[v]&&k<j;++k)
{
dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]);
}
}
siz[u]+=siz[v];
}
}
void solve(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int fa;
cin>>fa;
cin>>val[i];
vec[fa].push_back(i);
}
dfs(0);
cout<<dp[0][m+1]<<"\n"; //选择了第0号节点所以把m变成m+1
}
signed main(){
int t=1;
//cin>>t;
while(t--){
solve();
}
}
4.求断开任意一条边后,两颗子树分别的最长直径
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
#define maxn 200005
vector<int>vec[maxn];
int n,u[maxn],v[maxn],a[maxn],down[maxn],up[maxn],dp[maxn][4],len[maxn][2],d[maxn];
/*
第一次dfs预先处理出的数据dp[u][0|1|2]表示u的最长|次长|次次长 链
down[u]表示u子树里的最长链长度 */
void dfs1(int u,int pre,int dep){
d[u]=dep;
for(int i=0;i<vec[u].size();i++){
int v=vec[u][i];
if(v==pre)continue;
dfs1(v,u,dep+1);
int tmp=dp[v][0]+1;
for(int i=0;i<=2;i++){
if(tmp>dp[u][i]){
swap(dp[u][i],tmp);
}
}
down[u]=max(down[u],down[v]);
}
down[u]=max(down[u],dp[u][0]+dp[u][1]);
}
/*
第二次dfs求出
len[u][0|1]表示u子树里的不经过u的最长|次长 链
枚举每个儿子v,
求出dp[v][3]表示v上面的最长链
up[v]表示切断(u,v)后,u所在块的最长链
那么切断(u,v)后,两个子树的直径就是down[v],up[v]
down[v]好求,up[v]要通过u来求出
换根的过程:从u换到v时,up[v]有两种情况,一种是一条u的上面+一条u的下面,另一种是两条u的下面 */
void dfs2(int u,int pre){
for(int i=0;i<vec[u].size();i++){//求出u下不经过u的最长|次长 链
int v=vec[u][i];
if(v==pre)continue;
int tmp=down[v];
if(tmp>len[u][0])swap(tmp,len[u][0]);
if(tmp>len[u][1])swap(tmp,len[u][1]);
}
for(int i=0;i<vec[u].size();i++){//边(u,v)将原树分成两棵子树
int v=vec[u][i];
if(v==pre)continue;
//原树删掉v子树后,求v上的最长链dp[v][3],同时求出第一种情况的up[v]
if(dp[u][0]==dp[v][0]+1){//v是u的最深子树
dp[v][3]=max(dp[u][3],dp[u][1])+1;
up[v]=max(dp[u][2],dp[u][3])+dp[u][1];
}
else if(dp[u][1]==dp[v][0]+1){//v是u的次深子树
dp[v][3]=max(dp[u][0],dp[u][3])+1;
up[v]=max(dp[u][2],dp[u][3])+dp[u][0];
}
else {//v是u的其他子树
dp[v][3]=max(dp[u][0],dp[u][3])+1;
up[v]=max(dp[u][1],dp[u][3])+dp[u][0];
}
//求第二种情况的up[v],也要特别判一下v子树里是否有u下的最长链
if(len[u][0]==down[v])up[v]=max(up[v],len[u][1]);
else up[v]=max(up[v],len[u][0]);
dfs2(v,u);
}
}
void init(){
memset(dp,0,sizeof dp);
memset(a,0,sizeof a);
memset(len,0,sizeof len);
memset(down,0,sizeof down);
memset(up,0,sizeof up);
for(int i=1;i<=n;i++)vec[i].clear();
}
void solve(){
init();
cin>>n;
for(int i=1;i<n;i++){
cin>>u[i]>>v[i];
vec[u[i]].push_back(v[i]);
vec[v[i]].push_back(u[i]);
}
dfs1(1,1,0);
dfs2(1,1);
for(int i=1;i<n;i++){
int x=u[i],y=v[i];
if(d[x]<d[y])swap(x,y);
a[up[x]+1]=max(a[up[x]+1],down[x]+1);
a[down[x]+1]=max(a[down[x]+1],up[x]+1);
}
long long ans=0;
for(int i=n;i>=1;i--){
a[i]=max(a[i],a[i+1]);
ans+=a[i];
}
cout<<ans<<"\n";
}
signed main(){
io;
int t=1;
cin>>t;
while(t--){
solve();
}
}