51nod 1405 树的距离之和 - dblank

51nod 1405 树的距离之和

给定一棵无根树,假设它有n个节点,节点编号从1到n, 求任意两点之间的距离(最短路径)之和。
Input
第一行包含一个正整数n (n <= 100000),表示节点个数。
后面(n - 1)行,每行两个整数表示树的边。
Output
每行一个整数,第i(i = 1,2,...n)行表示所有节点到第i个点的距离之和。
Input示例
4
1 2
3 2
4 2
Output示例
5
3
5
5
数学题,也可以说dp,不太难。
(1)我们给树规定一个根。假设所有节点编号是0-(n-1),我们可以简单地把0当作根,这样下来父子关系就确定了。
(2)定义数组num[x]表示以节点x为根的子树有多少个节点,dp[x]是我们所求的——所有节点到节点x的距离之和。
(3)在步骤(1)中,其实我们同时可以计算出 num[x],还可以计算出每个节点的深度(每个到根节点0的距离),累加全部节点深度得到的其实就是是dp[0]。
(4) 假设一个非根节点x,它的父亲节点是y, 并且dp[y]已经计算好了,我们如何计算dp[x]?
以x为根的子树中那些节点,到x的距离比到y的距离少1, 这样的节点有num[x]个。
其余节点到x的距离比到y的距离多1,这样的节点有(n - num[x])个。
于是我们有 dp[x] = dp[y] - num[x] + (n - num[x])
= dp[y] + n - num[x] * 2
因为树的根节点dp[0]在步骤(3)已经计算出来了,根据所有的父子关系和这个上式,我们可以按照顺序计算出整个dp数组。
#include<iostream>  
#include<algorithm>  
#include<cstdio>  
#include<queue>  
#include<map>  
#include<vector>  
#include<cstring>  
#include<cmath>  
#pragma comment(linker, "/STACK:10240000,10240000")  //手动扩栈
using namespace std;  
typedef long long ll;  
const int inf =0x3f3f3f3f;  
const double  pi = acos(-1.0);  
const int N = 1e5 + 10;  
int root[N], num[N], len[N], vis[N], n;  
vector<int>path[N];  
ll dp[N];  
void dfs1(int x)  
{  
    num[x]  = 1;  
    for(int i = 0, sz = (int)path[x].size(); i<sz; i++)  
    {  
        int y = path[x][i];  
        if(!vis[y])  
        {  
            vis[y] = 1;  
            len[y] = len[x] + 1;  
            dp[1] += len[y];  
            dfs1(y);  
            num[x] += num[y];  
        }  
    }  
}  
void dfs2(int x)  
{  
    for(int i = 0, sz = (int)path[x].size(); i<sz; i++)  
    {  
        int y = path[x][i];  
        if(!vis[y])  
        {  
            dp[y] = dp[x] - num[y] + n - num[y];  
            vis[y] = 1;  
            dfs2(y);  
        }  
    }  
}  
int main()  
{  
    scanf("%d", &n);  
    int u, v;  
    for(int i = 1; i<n; i++)  
    {  
        scanf("%d%d", &u, &v);  
        path[u].push_back(v);  
        path[v].push_back(u);  
    }  
    vis[1] = 1;  
    dfs1(1);  
    memset(vis, 0, sizeof(vis));  
  //  for(int i = 1; i<=n; i++)  
  //      printf("%d\n", num[i]);  
    vis[1] = 1;  
    dfs2(1);  
    for(int i = 1; i<=n; i++)  
        printf("%I64d\n", dp[i]);  
        return 0;  
}

 

相关文章

发表新评论