"蔚来杯"2022牛客暑期多校训练营5-D Birds in the tree
原题题面:https://ac.nowcoder.com/acm/contest/33190/D
题目大意
给定一个有 n n n个节点的树,每个节点有一个颜色,分别有 0 0 0和 1 1 1。
求这颗树里有多少连通的子图,其中度数为 1 1 1的节点的颜色相同。
注意:符合要求的子图包括单独一个节点
由于答案可能过大,请对 1 e 9 + 7 1e9+7 1e9+7进行取模然后输出
解题思路
想到树形 d p dp dp,可以对每个节点 x x x与每个颜色 c c c做处理,设 d p x , c dp_{x,c} dpx,c为从以 x x x为根的子树中的 x x x被选择,除 x x x外叶节点颜色c的方案数。由于若 x x x为叶节点会影响后面的转移,所以排除。
对于 d p x , c dp_{x,c} dpx,c,其子树的方案数为从 d p s o n , c dp_{son,c} dpson,c个方案中任取一个或不取,根据乘法原理,方案总数 d p x , c = ∏ s ∈ s o n x ( d p s , c + 1 ) dp_{x,c}=\prod_{s\in son_x}(dp_{s,c}+1) dpx,c=s∈sonx∏(dps,c+1)
设 x x x的颜色为 c o l x col_x colx,最终答案先加上以 x x x为叶节点的方案数,加上只有 x x x的图,再加上不以 x x x为叶节点的方案,去掉其中与子节点重复的方案与只有 x x x的图。得:
a n s = ∑ x ( ∑ s ∈ s n x d p s , c o l x + 1 ) + ∑ x ∑ c ∈ { 0 , 1 } ( d p x , c − ∑ s ∈ s o n x d p s , c − ( c o l c = = c ) ) \begin{aligned} ans&=\sum_x(\sum_{s\in sn_x}dp_{s,col_x}+1)\\ &+\sum_{x}\sum_{c\in\{ 0,1\}}(dp_{x,c}-\sum_{s\in son_x}dp_{s,c}-(col_c==c))\\ \end{aligned} ans=x∑(s∈snx∑dps,colx+1)+x∑c∈{0,1}∑(dpx,c−s∈sonx∑dps,c−(colc==c))
∑ x ( + 1 ) \sum_x(+1) ∑x(+1)与 ∑ x ( − ( c o l c = = c ) ) \sum_x (-(col_c==c)) ∑x(−(colc==c))可以抵消
∑ x ( ∑ s ∈ s n x d p s , c o l x ) \sum_x(\sum_{s\in sn_x}dp_{s,col_x}) ∑x(∑s∈snxdps,colx)与 ∑ x ( − ∑ s ∈ s o n x d p s , c ) \sum_x(-\sum_{s\in son_x}dp_{s,c}) ∑x(−∑s∈sonxdps,c)可以消掉一部分
然后整理一下:
S ∈ { 0 , 1 } a n s = ∑ x ( ∑ c ∈ { 0 , 1 } d p x , c − ∑ s ∈ s o n x d p s , ∁ S c o l x ) \begin{aligned} S&\in\{0,1\}\\ ans&=\sum_x(\sum_{c\in\{0,1\}}dp_{x,c}-\sum_{s\in son_x}dp_{s,\complement_Scol_x}) \end{aligned} Sans∈{0,1}=x∑(c∈{0,1}∑dpx,c−s∈sonx∑dps,∁Scolx)
代码实现
注意取模,可以开 l o n g l o n g longlong longlong保险一点。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=3e5+5,mod=1e9+7;
ll dp[N][2],ans;
vector<int>G[N];
int n;
char s[N];
void dfs(int u,int fa){
dp[u][0]=dp[u][1]=1;
for(auto son:G[u]){
if(son==fa)continue;
dfs(son,u);
dp[u][1]=(1ll*dp[u][1]*(dp[son][1]+1))%mod;
dp[u][0]=(1ll*dp[u][0]*(dp[son][0]+1))%mod;
ans=((ans-dp[son][1^(s[u]^48)])%mod+mod)%mod;
}
dp[u][1^(s[u]^48)]--;
ans=(ans+dp[u][0]+dp[u][1])%mod;
}
int main(){
cin>>n;
scanf("%s",s+1);
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
G[x].push_back(y);
G[y].push_back(x);
}
dfs(1,0);
cout<<ans;
}