Description
现有一棵 n 个节点的棵, 树上每条边的长度均为 1。 给出 m 个询问, 每次询问两个节点 x,y, 求树上到 x,y 两个点距离相同的节点数量。
Input
第一个整数 n, 表示树有 n 个点。
接下来 n-1 行每行两整数 a, b, 表示从 a 到 b 有一条边。
接下来一行一个整数 m, 表示有 m 个询问。
接下来 m 行每行两整数 x, y, 询问到 x 和 y 距离相同的点的数量。
Output
共 m 行, 每行一个整数表示询问的答案。
Sample Input 1
7 1 2 1 3 2 4 2 5 3 6 3 7 3 1 2 4 5 2 3
Sample Output 1
0 5 1
Hint
对于 30%的数据, 满足 n≤50, m≤50
对于 60%的数据, 满足 n≤1000, m≤1000
对于 100%的数据, 满足 n≤100000, m≤100000
思路:如果存在点到x,y距离相等,这个点一定是中点,或在中点的其他子树上
写法思路:
首先特判if(x==y) ans=n
然后用最近公共祖先LCA算法计算两点间的距离dis(即路径上的树边数),如果dis是奇数则中点不在节点上,ans=0
dis为偶数则可以找到中点。从两点中深度较大的那个点(设这个点是x)向上爬dis/2个距离找到中点mid。
容易想到ans=n-size[x所在子树]-size[y所在子树],但是受建树时根节点不同的影响,y可能不在mid的子树里,而在mid的父亲那条分支里
这时候ans=size[mid]-size[x所在子树]
用if(mid==lca(x,y))为真判定y在mid的子树里(想一想,为什么),之后就没什么细节了。
代码:
#include<iostream> #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> #define maxn 100005 #define maxm 200005 #define id(x) ((x+1)>>1) using namespace std; int fir[maxn], ne[maxm], to[maxm], np; void add(int x,int y){ ne[++np] = fir[x]; fir[x] = np; to[np] = y; } int dep[maxn], fa[maxn][20], siz[maxn]; void dfs(int u,int f,int d){ dep[u] = d; siz[u] = 1; fa[u][0] = f; for(int k = 1; k <= 18; k++){ int j = fa[u][k-1]; fa[u][k] = fa[j][k-1]; } for(int i = fir[u]; i; i=ne[i]){ int v = to[i]; if(v != f) dfs(v, u, d+1), siz[u] += siz[v]; } } int jump(int u, int x) { for(int k = 18; k >= 0; k--) if((1<<k)&x) u = fa[u][k]; return u; } int jump2(int u,int anc){ for(int k = 18; k >= 0; --k) if(dep[fa[u][k]] > dep[anc]) u = fa[u][k]; return u; } int LCA(int x,int y){ x = jump(x, dep[x] - dep[y]); if(x == y) return x; for(int k = 18; k >= 0; k--) if(fa[x][k] != fa[y][k]) x = fa[x][k], y = fa[y][k]; return fa[x][0]; } int n, m; void data_in() { memset(fir, 0, sizeof(fir)); np = 0; int u, v; scanf("%d", &n); for(int i = 1; i < n; ++i) { scanf("%d%d", &u, &v); add(u, v); add(v, u); } } void solve() { dfs(1, 0, 1); int u, v, mid, dis, lca; scanf("%d", &m); while(m--) { scanf("%d%d", &u, &v); if(u==v)printf("%d\n", n); else{ if(dep[u] < dep[v]) swap(u, v); dis = dep[u] + dep[v] - 2*dep[lca = LCA(u, v)]; if(dis%2 == 0) { mid = jump(u, dis/2); u = jump2(u, mid); if(mid == lca){ v = jump2(v, mid); printf("%d\n", n - siz[u] - siz[v]); } else printf("%d\n", siz[mid] - siz[u]); } else printf("0\n"); } } } int main(){ data_in(); solve(); return 0; }
原文地址:https://www.cnblogs.com/de-compass/p/11521247.html