给出一棵树,许多询问,每次询问A,B,C三点,求一点使到三点距离最小,输出该点和最小值。
很明显就是求LCA,三种组合都求一次LCA,然后在里面选个距离和最小的就行了。
官方题解里面的代码求LCA是在线DFS RMQ的方法..先记录欧拉序,且记录某个点在序列里的第一个位置,每次询问a,b的LCA就是询问两者在欧拉序列里第一个位置之差中的那些点里面深度最小的
LCA(a,b)=RMQ(dep, pos[a], pos[b])
/** @Date : 2017-09-27 20:41:28 * @FileName: CS20 C LCA RMQ.cpp * @Platform: Windows * @Author : Lweleth ([email protected]) * @Link : https://github.com/ * @Version : $Id$ */ #include <bits/stdc++.h> #define LL long long #define PII pair<int ,int> #define MP(x, y) make_pair((x),(y)) #define fi first #define se second #define PB(x) push_back((x)) #define MMG(x) memset((x), -1,sizeof(x)) #define MMF(x) memset((x),0,sizeof(x)) #define MMI(x) memset((x), INF, sizeof(x)) using namespace std; const int INF = 0x3f3f3f3f; const int N = 1e5+20; const double eps = 1e-8; int dep[N]; int pos[N]; int rmq[19][2*N]; int eul[2*N], c, l[2*N]; vector<int> edg[N]; void dfs(int x, int pre) { eul[++c] = x; pos[x] = c; if(pre) dep[x] = dep[pre] + 1; for(auto i: edg[x]) { if(i == pre) continue; dfs(i, x); eul[++c] = x; } } void init() { dfs(1, 0); for(int i = 2; i <= c; i++)//预处理 2^k=x对应的k l[i] = l[i / 2] + 1; for(int i = 1; i <= c; i++) rmq[0][i] = eul[i]; for(int j = 1; (1 << j) <= c; j++) for(int i = 1; i <= c; i++) { rmq[j][i] = rmq[j - 1][i]; if(i + (1 << (j - 1)) > c) continue; if(dep[rmq[j - 1][i + (1 << (j - 1))]] < dep[rmq[j][i]]) rmq[j][i] = rmq[j - 1][i + (1 << (j - 1))]; } } int lca(int x, int y) { if(pos[x] > pos[y]) swap(x, y); int dis = pos[y] - pos[x] + 1; int k = l[dis]; if(dep[rmq[k][pos[x] + dis - (1 << k)]] < dep[rmq[k][pos[x]]]) return rmq[k][pos[x] + dis - (1 << k)]; else return rmq[k][pos[x]]; } int distance(int a, int b) { int ac = lca(a, b); return dep[a] + dep[b] - 2 * dep[ac]; } int main() { int n, q; cin >> n >> q; for(int i = 0; i < n - 1; i++) { int x, y; scanf("%d%d", &x, &y); edg[x].PB(y); edg[y].PB(x); } init(); while(q--) { int a, b, c; scanf("%d%d%d", &a, &b, &c); int ac1 = lca(a, b); int ac2 = lca(a, c); int ac3 = lca(b, c); int ans1 = distance(ac1, a) + distance(ac1, b) + distance(ac1, c); int ans2 = distance(ac2, a) + distance(ac2, b) + distance(ac2, c); int ans3 = distance(ac3, a) + distance(ac3, b) + distance(ac3, c); if(ans1 > ans2) swap(ans1, ans2), swap(ac1, ac2); if(ans1 > ans3) swap(ans1, ans3), swap(ac1, ac3); printf("%d %d\n", ac1, ans1); } return 0; }
时间: 2024-10-29 04:41:24