HDU 6035 — centroid decomposition (point divide-and-conquer) with combinatorial counting

Colorful Tree

https://vjudge.net/problem/938230/origin
There is a tree with $n$ nodes, each of which has a color represented by an integer; the color of node $i$ is $c_i$.

The path between any two distinct nodes is unique; we define its value as the number of distinct colors appearing on it.

Calculate the sum of the values of all $\frac{n(n-1)}{2}$ paths in the tree.

Summary: a tree has n nodes, each colored with an integer in [1, n]. The weight of a path is the number of distinct colors on it. Compute the sum of the weights of all paths.
Idea: in the centroid decomposition, at each step we count the contribution of all paths that pass through the current centroid root, treating the component as a tree rooted at root. Within one child subtree, suppose id1 and id2 have the same color and id1 is an ancestor of id2. Then every path that passes through id2 and root also passes through id1, so only id1's occurrence of that color contributes. It is therefore enough to compute the contribution of each node at which a color first appears on the way down from the child subtree's root, that is, the number of paths through that node.

Use the following notation:
- id is such a first-appearance node.
- sz[id] is the size of the subtree rooted at id.
- v is the root of the whole child subtree (the neighbor of root).
- val[id] is the color of node id.
- num[val[id]] starts as the number of nodes in the component whose centroid is root.

The contribution of node id is then (num[val[id]] − sz[v]) · sz[id], since each such path passes through both the centroid and id. Afterwards, subtract sz[id] from num[val[id]]. The paths ending inside id's subtree are now counted, and they would be counted again when the same color is processed in another child subtree.

The centroid's own color is handled separately. num1[v] counts the nodes of subtree v whose path to v does not already contain root's color. Pairs made of such a node and either root itself or another such node in a different subtree get that color from root alone.

Note: if two nodes have the same color but neither is an ancestor of the other, both contributions must be counted. (I got one WA here.)

#include<bits/stdc++.h>
#define MAXN 200010
#define INF 0x3f3f3f3f
#define ll long long
using namespace std;
// Adjacency lists threaded through edg[]: head[x] is the index of the edge
// most recently added out of x (-1 when x has none) and tot is the next free
// slot in edg[]. Each undirected tree edge is stored twice, once per direction.
int head[MAXN],tot;
struct edge
{
    int v;    // node this directed edge leads to
    int nxt;  // index of the next edge out of the same source node, or -1
}edg[MAXN << 1];

// Prepends the directed edge u -> v to u's adjacency list.
inline void addedg(int u,int v)
{
    edge &slot = edg[tot];
    slot.v = v;
    slot.nxt = head[u];
    head[u] = tot;
    ++tot;
}
int n;          // number of nodes in the current test case
int mx;         // smallest "largest remaining piece" seen so far by getroot
int root;       // node chosen by getroot (the centroid of the current component)
int Size;       // number of nodes in the component currently being processed
int sz[MAXN];   // subtree sizes written by the most recent getroot/getdis walk
ll ans;         // running answer for the current test case
int num[MAXN];  // num[c]: reset to Size by getroot, then reduced by solve as paths carrying colour c are counted
int val[MAXN];  // colour of each node
bool vis[MAXN]; // set once a node has been used as a centroid; walks never cross such nodes
inline void getroot(int u,int f)
{
    int v,mson = 0;
    sz[u] = 1;
    num[val[u]] = Size;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(v == f || vis[v]) continue;
        getroot(v,u);
        sz[u] += sz[v];
        mson = max(mson,sz[v]);
    }
    mson = max(Size-sz[u],mson);
    if(mson < mx)
        mx = mson,root = u;
}
int color[MAXN];    // color[k] (1-based): k-th colour collected by getdis at its first appearance below the subtree root
int id[MAXN];       // id[k]: node at which color[k] first appears
int cnt;            // number of entries currently in color[] / id[]
int viscolor[MAXN]; // nonzero while a node of that colour is on the current getdis DFS path
int num1[MAXN];     // num1[v]: nodes in child subtree v whose path to v contains no node of the centroid's colour
inline void getdis(int u,int f)
{
    sz[u] = 1;
    bool flag = false;
    if(!viscolor[val[u]])
        viscolor[val[u]] = 1,flag = true,color[++cnt] = val[u],id[cnt] = u;
    int v;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(v == f || vis[v]) continue;
        getdis(v,u);
        sz[u] += sz[v];
    }
    if(flag)
        viscolor[val[u]] = 0;
}
inline void solve(int u,int ssize)
{
    vis[u] = 1;
    int v;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        cnt = 0;
        getdis(v,v);
        num1[v] = sz[v];
        for(int i = 1;i <= cnt;++i)
        {
            int nn = id[i];
            ans += 1ll * sz[nn] * (num[color[i]]-sz[v]);
            if(color[i] == val[u])
                num1[v] -= sz[nn];
        }
        for(int i = 1;i <= cnt;++i)
            num[color[i]] -= sz[id[i]];
    }
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        ans += 1ll*num1[v] * (num[val[u]] - num1[v]);
        num[val[u]] -= num1[v];
    }
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        Size = sz[v];
        mx = INF;
        getroot(v,v);
        solve(root,Size);
    }
}
// Resets per-test-case state before a tree with n nodes is read: the answer,
// the edge counter, the centroid-search globals, and head[] / vis[] for
// indices 0..n.
inline void init()
{
    ans = 0;
    tot = 0;
    Size = n;
    mx = INF;
    memset(vis,false,sizeof(bool)*(n+1));
    memset(head,-1,sizeof(int)*(n+1));
}
// Reads test cases until end of input. Each case has n, then n colours, then
// n-1 edges. For each case it prints "Case #t: <sum over all paths of the
// number of distinct colours>".
int main()
{
    int t = 0;
    // scanf returns the number of items converted, or EOF (-1) at end of input.
    // The former test `~scanf(...)` is false only for -1. A malformed token
    // (return value 0) would therefore loop forever, so require exactly one item.
    while(scanf("%d",&n) == 1)
    {
        ++t;
        init();
        for(int i = 1;i <= n;++i)
            scanf("%d",&val[i]);
        int u,v;
        for(int i = 1;i < n;++i)
        {
            scanf("%d%d",&u,&v);
            addedg(u,v),addedg(v,u);
        }
        // An empty tree has no paths, and ans is already 0 from init().
        // Skipping the decomposition also avoids walking head[1], which
        // init() does not reset when n == 0.
        if(n > 0)
        {
            getroot(1,1);
            solve(root,Size);
        }
        printf("Case #%d: %lld\n",t,ans);
    }
    return 0;
}
/*
7
6 3 3 1 1 1 2
2 1
3 1
4 1
5 4
6 5
7 5

 */
Published 47 original articles, won praise 1, visited 2546
Private letter follow

Added by akluch on Sat, 18 Jan 2020 13:02:57 +0200