树形DP泛做

发布于:2022-11-01 ⋅ 阅读:(466) ⋅ 点赞:(0)

因为树形DP也算是dp里面的一种类型,虽然感觉没有什么总结的必要,但是碰到一些经典的套路还是想要记录一下,所以发在这个博客里

树形dp可以分为两种类型,一个是选择节点类,一个是树上背包类,大致如下

1. 牛客小白月赛55F

题目描述:

给出一个拓扑序关系的树型图,问图上的节点有多少种排列可能

思路:

首先考虑下面这个问题:有a个红气球,b个绿气球,c个蓝气球有多少种排列组合,这就是一个多重全排列问题 ans=(a+b+c)!/a!/b!/c!  即为所有气球的全排列除以各个相同颜色的气球的全排列,即各个子集和的全排列除以各个子集和的限制条件    

那么对于这一题,用这个思想,他们满足以下关系

代码:

#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=2e5+5;
const int inf=1e9+7;
const int mod=1e9+7;
vector<int>vec[maxn];
int fac[maxn],inv[maxn];
void init(int n=1e5+5){
	fac[0]=1;
	fac[1]=1;
	for(int i=2;i<=n;i++){
		fac[i]=fac[i-1]*i%mod;
	}
	inv[1]=1;
	for(int i=2;i<=n;i++){
		inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	}
	for(int i=2;i<=n;i++){
		inv[i]=inv[i]*inv[i-1]%mod;
	}
}
int dp[maxn];
int siz[maxn];
void dfs1(int u,int fa){
	siz[u]=1;
	for(int v:vec[u]){
		if(v!=fa){
			dfs1(v,u);
			siz[u]+=siz[v];
		}
	}
}
void dfs(int u,int fa){
	dp[u]=fac[siz[u]-1];
	for(int v:vec[u]){
		if(v!=fa){
			dfs(v,u);
			dp[u]=dp[u]*dp[v]%mod;
			dp[u]=dp[u]*inv[siz[v]]%mod;
		}
	}
}
int dpp[maxn];
int sizz[maxn];
void solve(){
	int n;
	cin>>n;
	int zong=0;
	for(int i=1;i<=n;i++){
		int m;
		cin>>m;
		for(int j=0;j<=m;j++){
			vec[j].clear();		
		}
		for(int j=2;j<=m;j++){
			int u,v;
			cin>>u;
			v=j;
			vec[u].push_back(v);
			vec[v].push_back(u);
		}
		dfs1(1,0);
		dfs(1,0);
		dpp[i]=dp[1];
		sizz[i]=siz[1];
		zong+=sizz[i];
	}
	int ans=fac[zong];
	for(int i=1;i<=n;i++){
		ans=ans*dpp[i]%mod;
		ans=ans*inv[sizz[i]]%mod;
	}
	cout<<ans<<"\n";
}
signed main(){
	init();
	int t=1;
	//cin>>t;
	while(t--){
		solve();
	}
}

2.Codeforces Round #686 (Div. 3) E

这题是一个基环树,即n个节点n条边,那么一定会构成一个环,把这个环上面的各个点当作树的根节点,那么就是一个环上面有若干颗树

一个经典的套路:树里面任意两点都有一条简单路径,所以假设一颗树有n个节点,则有n*(n-1)/2条简单路径

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
set<int> G[N];
ll vis[N];
int main() {
  ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
  int T; cin >> T;
  while(T--) {
    int n; cin >> n;
    for(int i = 1; i <= n; i++) G[i].clear(), vis[i] = 1;
    for(int i = 1; i <= n; i++) {
      int u, v; cin >> u >> v;
      G[u].insert(v);
      G[v].insert(u);
    }
    queue<int> q;
    for(int i = 1; i <= n; i++) {
      if(G[i].size() == 1) {
        q.push(i);
        vis[i] = 1;
      }
    }
    while(!q.empty()) {
      int u = q.front(); q.pop();
      for(auto &v : G[u]) {
        vis[v] += vis[u];
        vis[u] = 0;
        G[v].erase(u);
        if(G[v].size() == 1) {
          q.push(v);
        }
      }
    }  
    ll ans = 0;
    for(int i = 1; i <= n; i++) {
      ans += (vis[i] - 1) * vis[i] / 2;
      ans += vis[i] * (n - vis[i]);
    } 
    cout << ans << '\n';
  }  
  return 0;
}

3.P2014 [CTSC1997] 选课

 树上背包板子题

此题的多叉转二叉优化代码

#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1005;
const int inf=1e9+7;
const int mod=1e9+7;
int n,m;
vector<int>vec[maxn];
int val[maxn];
int dp[maxn][maxn];  //代表以i为根节点的子树中选择j个节点的最大价值
void dfs(int u,int t){
	if(t<=0){
		return;
	}
	for(int v:vec[u]){
		for(int k=0;k<t;k++){
			dp[v][k]=dp[u][k]+val[v];
		}
		dfs(v,t-1);
		for(int k=1;k<=t;k++){
			dp[u][k]=max(dp[u][k],dp[v][k-1]);
		}
	}
}
void solve(){
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		int fa;
		cin>>fa;
		cin>>val[i];
		vec[fa].push_back(i);
	}
	dfs(0,m);
	cout<<dp[0][m]<<"\n";
}
signed main(){
	int t=1;
	//cin>>t;
	while(t--){
		solve();
	}
}

此题的上下界优化代码

#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1005;
const int inf=1e9+7;
const int mod=1e9+7;
int n,m;
vector<int>vec[maxn];
int val[maxn];
int dp[maxn][maxn];  //代表以i为根节点的子树中选择j个节点的最大价值
int siz[maxn];
void dfs(int u){
	siz[u]=1;
	dp[u][1]=val[u];
	for(int v:vec[u]){
		dfs(v);
		for (int j=min(m+1,siz[u]+siz[v]);j>=1;--j)
		{
			for (int k=max(1ll,j-siz[u]);k<=siz[v]&&k<j;++k)
			{
				dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]);
			}
		}
		siz[u]+=siz[v];
	}
}
void solve(){
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		int fa;
		cin>>fa;
		cin>>val[i];
		vec[fa].push_back(i);
	}
	dfs(0);
	cout<<dp[0][m+1]<<"\n";  //选择了第0号节点所以把m变成m+1
}
signed main(){
	int t=1;
	//cin>>t;
	while(t--){
		solve();
	}
}

4.求断开任意一条边后,两颗子树分别的最长直径

#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
#define maxn 200005
vector<int>vec[maxn];
int n,u[maxn],v[maxn],a[maxn],down[maxn],up[maxn],dp[maxn][4],len[maxn][2],d[maxn];
/*
第一次dfs预先处理出的数据dp[u][0|1|2]表示u的最长|次长|次次长 链
down[u]表示u子树里的最长链长度 */
void dfs1(int u,int pre,int dep){
	d[u]=dep;    
	for(int i=0;i<vec[u].size();i++){
		int v=vec[u][i];
		if(v==pre)continue;
		dfs1(v,u,dep+1);
		int tmp=dp[v][0]+1;
		for(int i=0;i<=2;i++){
			if(tmp>dp[u][i]){
				swap(dp[u][i],tmp);
			}
		}
		down[u]=max(down[u],down[v]); 
	}
	down[u]=max(down[u],dp[u][0]+dp[u][1]);
}
/*
第二次dfs求出
len[u][0|1]表示u子树里的不经过u的最长|次长 链
枚举每个儿子v,
求出dp[v][3]表示v上面的最长链
up[v]表示切断(u,v)后,u所在块的最长链
那么切断(u,v)后,两个子树的直径就是down[v],up[v]
down[v]好求,up[v]要通过u来求出
换根的过程:从u换到v时,up[v]有两种情况,一种是一条u的上面+一条u的下面,另一种是两条u的下面 */
void dfs2(int u,int pre){
	for(int i=0;i<vec[u].size();i++){//求出u下不经过u的最长|次长 链 
		int v=vec[u][i];
		if(v==pre)continue;
		int tmp=down[v];
		if(tmp>len[u][0])swap(tmp,len[u][0]);
		if(tmp>len[u][1])swap(tmp,len[u][1]);
	}	
	for(int i=0;i<vec[u].size();i++){//边(u,v)将原树分成两棵子树 
		int v=vec[u][i];
		if(v==pre)continue; 
		//原树删掉v子树后,求v上的最长链dp[v][3],同时求出第一种情况的up[v] 
		if(dp[u][0]==dp[v][0]+1){//v是u的最深子树 
			dp[v][3]=max(dp[u][3],dp[u][1])+1;
			up[v]=max(dp[u][2],dp[u][3])+dp[u][1]; 
		} 
		else if(dp[u][1]==dp[v][0]+1){//v是u的次深子树 
			dp[v][3]=max(dp[u][0],dp[u][3])+1; 
			up[v]=max(dp[u][2],dp[u][3])+dp[u][0];
		}
		else {//v是u的其他子树 
			dp[v][3]=max(dp[u][0],dp[u][3])+1;
			up[v]=max(dp[u][1],dp[u][3])+dp[u][0]; 
		}
		//求第二种情况的up[v],也要特别判一下v子树里是否有u下的最长链 
		if(len[u][0]==down[v])up[v]=max(up[v],len[u][1]); 
		else up[v]=max(up[v],len[u][0]);
		dfs2(v,u);
	}
}
void init(){
	memset(dp,0,sizeof dp);
	memset(a,0,sizeof a);
	memset(len,0,sizeof len);
	memset(down,0,sizeof down);
	memset(up,0,sizeof up);
	for(int i=1;i<=n;i++)vec[i].clear();
}
void solve(){
	init();
	cin>>n;
	for(int i=1;i<n;i++){
		cin>>u[i]>>v[i];
		vec[u[i]].push_back(v[i]);
		vec[v[i]].push_back(u[i]);
	}
	dfs1(1,1,0);
	dfs2(1,1);
	for(int i=1;i<n;i++){
		int x=u[i],y=v[i];
		if(d[x]<d[y])swap(x,y);
		a[up[x]+1]=max(a[up[x]+1],down[x]+1);
		a[down[x]+1]=max(a[down[x]+1],up[x]+1);
	}
	long long ans=0;
	for(int i=n;i>=1;i--){
		a[i]=max(a[i],a[i+1]);
		ans+=a[i];
	}
	cout<<ans<<"\n";
}
signed main(){
	io;
	int t=1;
	cin>>t;
	while(t--){
		solve();
	}
}


网站公告

今日签到

点亮在社区的每一天
去签到