题目大意:给出一棵树,每次询问一些节点,当把这些节点连接起来,使得每两个询问的点之间有一条边,共有k*(k - 1)条边。问这些边中,长度的总和是多少,最短的一条边是多少,最长的一条边是多少。保证询问的点的总数是O(n)级别。
思路:利用LCA单调性,每次询问的时候重新建树,在这棵树上做DP,使得总体时间复杂度降到O(nlogn)。
树形DP我写的都要麻烦死了。。听了正解之后简直想吐血。。
我的做法是维护四个数组,sum,size,_min,_max,分别表示以当前节点为根节点的子树中的所有关键点到根节点的距离的总和,共有多少个关键点,距离根节点最近的关键点的距离,距离跟节点最远的关键点的距离。此外,在做DP的同时,除了最值还要记录一下次值,用__min和__max表示。记录一个全局变脸来表示最终答案。DP方程(y表示x的一个子树的根节点):
size[x] = ∑size[y] + super[x];
sum[x] = ∑(sum[y] + length * size[y]);
_min[x] = min{_min[y] + length}
_max[x] = max{_max[y] + length}
注意还要更新一下次值
更新答案的表达式:
ans += ∑((sum[y] + length * size[y]) * (size[x] - size[y]));
ans_min = min(ans_min,_min[x] + __min);
ans_max = max(ans_max,_max[x] + __max);
没了。。记得开long long
其实完全不用记录次值,只需要两两子树合并就可以了,然后都取最值。
CODE:
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define MAX 1000010 #define INF (1ll << 60) using namespace std; int points,asks; int head[MAX],total; int next[MAX << 1],aim[MAX << 1]; int pos[MAX],cnt; int deep[MAX],father[MAX][20]; inline void Add(int x,int y) { next[++total] = head[x]; aim[total] = y; head[x] = total; } void DFS(int x,int last) { pos[x] = ++cnt; deep[x] = deep[last] + 1; father[x][0] = last; for(int i = head[x]; i; i = next[i]) { if(aim[i] == last) continue; DFS(aim[i],x); } } inline int GetLCA(int x,int y) { if(deep[x] < deep[y]) swap(x,y); for(int i = 19; ~i ; --i) if(deep[father[x][i]] >= deep[y]) x = father[x][i]; if(x == y) return x; for(int i = 19; ~i; --i) if(father[x][i] != father[y][i]) x = father[x][i],y = father[y][i]; return father[x][0]; } void MakeTable() { for(int j = 1; j <= 19; ++j) for(int i = 1; i <= points; ++i) father[i][j] = father[father[i][j - 1]][j - 1]; } long long ans_min,ans_max,ans; struct Graph{ int head[MAX],v[MAX],T,total; int next[MAX],aim[MAX]; long long length[MAX]; long long size[MAX],sum[MAX]; long long _min[MAX],_max[MAX]; int super[MAX]; void Reset() { total = 0; ++T; } void Set(int x) { super[x] = T; } inline void Add(int x,int y,long long len) { if(v[x] != T) { v[x] = T; head[x] = 0; } next[++total] = head[x]; aim[total] = y; length[total] = len; head[x] = total; } void TreeDP(int x) { size[x] = 0; sum[x] = 0; _min[x] = super[x] == T ? 0:INF; _max[x] = super[x] == T ? 0:-INF; long long __min = INF,__max = -INF; if(v[x] != T) v[x] = T,head[x] = 0; for(int i = head[x]; i; i = next[i]) { TreeDP(aim[i]); size[x] += size[aim[i]]; sum[x] += sum[aim[i]] + length[i] * size[aim[i]]; if(_max[aim[i]] + length[i] > _max[x]) { __max = _max[x]; _max[x] = _max[aim[i]] + length[i]; } else if(_max[aim[i]] + length[i] > __max) __max = _max[aim[i]] + length[i]; if(_min[aim[i]] + length[i] < _min[x]) { __min = _min[x]; _min[x] = _min[aim[i]] + length[i]; } else if(_min[aim[i]] + length[i] < __min) __min = _min[aim[i]] + length[i]; } for(int i = head[x]; i; i = next[i]) { ans += (sum[aim[i]] + length[i] * size[aim[i]]) * (size[x] - size[aim[i]]); if(super[x] == T) ans += (sum[aim[i]] + length[i] * size[aim[i]]); } ans_min = min(ans_min,_min[x] + __min); ans_max = max(ans_max,_max[x] + __max); size[x] += super[x] == T; } }graph; struct Complex{ int x,pos; Complex(int _,int __):x(_),pos(__) {} Complex() {} bool operator <(const Complex &a)const { return pos < a.pos; } }src[MAX]; int stack[MAX]; int main() { cin >> points; for(int x,y,i = 1; i < points; ++i) { scanf("%d%d",&x,&y); Add(x,y),Add(y,x); } DFS(1,0); MakeTable(); cin >> asks; for(int cnt,j = 1; j <= asks; ++j) { scanf("%d",&cnt); for(int i = 1; i <= cnt; ++i) scanf("%d",&src[i].x),src[i].pos = pos[src[i].x]; sort(src + 1,src + cnt + 1); graph.Reset(); int top = 0; stack[++top] = 1; for(int i = 1; i <= cnt; ++i) { int lca = GetLCA(stack[top],src[i].x); while(deep[stack[top]] > deep[lca]) { if(deep[stack[top - 1]] <= deep[lca]) { int away = stack[top--]; if(stack[top] != lca) stack[++top] = lca; graph.Add(stack[top],away,abs(deep[stack[top]] - deep[away])); break; } graph.Add(stack[top - 1],stack[top],abs(deep[stack[top - 1]] - deep[stack[top]])),--top; } if(stack[top] != src[i].x) stack[++top] = src[i].x; graph.Set(src[i].x); } while(top != 1) graph.Add(stack[top - 1],stack[top],abs(deep[stack[top - 1]] - deep[stack[top]])),--top; ans_min = INF,ans_max = 0,ans = 0; graph.TreeDP(1); printf("%lld %lld %lld\n",ans,ans_min,ans_max); } return 0; }
时间: 2024-10-11 11:34:57