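离线处理所有询问。对每个点，求出它到根的路径上权值落在 [a, b] 内的点权之和，即"权值小于 b+1 的点权之和"减去"权值小于 a 的点权之和"。因此把每个询问拆成三个查询：s 到根的和 S(s)、t 到根的和 S(t)，以及 lca(s, t) 到根的和 S(lca)；由容斥可得 s 到 t 路径上的答案为 S(s) + S(t) − 2·S(lca)，若 c[lca] 落在 [a, b] 内再加上 c[lca]。事先把这些查询挂到对应的点上（存入数组）。DFS 到某个点时，用 Treap 动态维护根到该点路径上的所有权值（进入时插入，回溯时删除），再用 Treap 查询小于某个数的权值之和。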
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <vector>
#include <algorithm>
#include <queue>
#include <time.h>
#include <iostream>
using namespace std;

// Offline path/value-range sum queries on a tree.
//
// Query (s, t, a, b): sum of c[x] over nodes x on the path s..t with a <= c[x] <= b.
// Let S(x) be that sum restricted to the path root..x. Then
//   answer = S(s) + S(t) - 2 * S(lca(s, t)) + (c[lca] if a <= c[lca] <= b).
// Each query is split into three weighted items attached to s, t and lca, and
// S(x) is evaluated during a DFS while a treap holds exactly the values on root..x.

const int maxn = 100005;
const int maxm = 20;   // binary-lifting levels; 2^19 > maxn

int c[maxn];           // node values, 1-indexed

// Treap node holding `num` copies of value `v`; `s` is the sum of the whole subtree.
// `r` is the heap priority (max-heap: a child's r never exceeds its parent's).
struct Node {
    Node *ch[2];
    int r, v, num;
    long long s;

    // -1 if x equals this node's value, 0 to go left, 1 to go right.
    int cmp(int x) const {
        if (x == v) return -1;
        return x < v ? 0 : 1;
    }

    // Recompute the subtree sum from this node's copies and its children.
    void maintain() {
        s = 1LL * v * num;
        if (ch[0] != NULL) s += ch[0]->s;
        if (ch[1] != NULL) s += ch[1]->s;
    }
};
Node *root;

// Forward-star adjacency list: head[u] is u's first edge index, edge[i].next the
// next one, -1 terminates the list. Each undirected tree edge is stored twice.
struct Edge {
    int to, next;
} edge[maxn * 2];
int head[maxn], tot;

void addedge(int u, int v) {
    edge[tot].to = v;
    edge[tot].next = head[u];
    head[u] = tot++;
}

// Reset the edge list before reading a new test case.
void init() {
    tot = 0;
    memset(head, -1, sizeof(head));
}

int fa[maxn][maxm];  // fa[u][i]: the 2^i-th ancestor of u (the root is its own parent)
int deg[maxn];       // depth of each node; the root has depth 0

// Fill deg[] and fa[][] by BFS from `root`.
void BFS(int root) {
    queue<int> que;
    deg[root] = 0;
    fa[root][0] = root;
    que.push(root);
    while (!que.empty()) {
        int tmp = que.front();
        que.pop();
        // A node is dequeued after all its ancestors, so their rows are complete.
        for (int i = 1; i < maxm; ++i) fa[tmp][i] = fa[fa[tmp][i - 1]][i - 1];
        for (int i = head[tmp]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v == fa[tmp][0]) continue;
            deg[v] = deg[tmp] + 1;
            fa[v][0] = tmp;
            que.push(v);
        }
    }
}

// Lowest common ancestor of u and v via binary lifting (requires BFS to have run).
int LCA(int u, int v) {
    if (deg[u] > deg[v]) swap(u, v);
    int tu = u, tv = v;
    // Lift the deeper node to the depth of the shallower one.
    for (int det = deg[v] - deg[u], i = 0; det; det >>= 1, ++i)
        if (det & 1) tv = fa[tv][i];
    if (tu == tv) return tu;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = maxm - 1; i >= 0; --i) {
        if (fa[tu][i] == fa[tv][i]) continue;
        tu = fa[tu][i];
        tv = fa[tv][i];
    }
    return fa[tu][0];
}

// Rotate o so that its child ch[d ^ 1] becomes the new subtree root.
void Rotate(Node* &o, int d) {
    Node *k = o->ch[d ^ 1];
    o->ch[d ^ 1] = k->ch[d];
    k->ch[d] = o;
    o->maintain();
    k->maintain();
    o = k;
}

// Insert one copy of x into the treap rooted at o.
void Insert(Node* &o, int x) {
    if (o == NULL) {
        o = new Node();
        o->ch[0] = o->ch[1] = NULL;
        o->v = x;
        o->r = rand();
        o->s = o->v;
        o->num = 1;
    } else {
        int d = o->cmp(x);
        if (d == -1) {  // value already present: just bump its multiplicity
            o->s += x;
            o->num += 1;
            return;
        }
        Insert(o->ch[d], x);
        if (o->ch[d]->r > o->r) Rotate(o, d ^ 1);  // restore heap order
    }
    o->maintain();
}

// Remove one copy of x from the treap rooted at o; x must be present.
void Remove(Node* &o, int x) {
    int d = o->cmp(x);
    if (d == -1) {
        if (o->num != 1) {
            o->num -= 1;
            o->s -= x;
            return;
        }
        Node *u = o;
        if (o->ch[0] == NULL) {
            o = o->ch[1];
            delete u;
        } else if (o->ch[1] == NULL) {
            o = o->ch[0];
            delete u;
        } else {
            // Rotate the higher-priority child up, then delete x from below.
            int d2 = (o->ch[0]->r > o->ch[1]->r ? 1 : 0);
            Rotate(o, d2);
            Remove(o->ch[d2], x);
        }
    } else {
        Remove(o->ch[d], x);
    }
    if (o != NULL) o->maintain();
}

// Sum of all stored values strictly less than x. Takes long long so callers can
// pass b + 1 without int overflow.
long long Rank(Node *u, long long x) {
    long long total = 0;
    while (u != NULL) {
        if (x <= u->v) {
            u = u->ch[0];
        } else {
            // u and its whole left subtree are < x.
            total += (u->ch[0] == NULL ? 0 : u->ch[0]->s) + 1LL * u->v * u->num;
            u = u->ch[1];
        }
    }
    return total;
}

// Free every node of the treap rooted at u.
void del(Node *u) {
    if (u == NULL) return;
    del(u->ch[0]);
    del(u->ch[1]);
    delete u;
}

// One weighted sub-query attached to a node: contributes v * S(node) to ans[id].
struct Item {
    int a, b, v, id;
    Item(int a, int b, int v, int id) : a(a), b(b), v(v), id(id) { }
};
vector<Item> que[maxn];
long long ans[maxn];

// Called when the DFS enters u: add c[u], then answer u's sub-queries.
// At this point the treap holds exactly the values on the path root..u.
static void enterNode(int u) {
    Insert(root, c[u]);
    for (size_t i = 0; i < que[u].size(); ++i) {
        const Item &q = que[u][i];
        ans[q.id] += 1LL * q.v * (Rank(root, 1LL * q.b + 1) - Rank(root, q.a));
        // The lca item (weight -2) removes the lca's own value twice; add it back once.
        if (q.v < 0 && c[u] >= q.a && c[u] <= q.b) ans[q.id] += c[u];
    }
}

// Explicit-stack DFS state (recursion depth could reach n on a path-shaped tree).
static int dfsStack[maxn], dfsParent[maxn], dfsEdge[maxn];

// Iterative DFS from u; `father` is u's parent (-1 for the root).
// Each node's value is inserted on entry and removed once all its children finish.
void dfs(int u, int father) {
    int top = 0;
    dfsStack[top++] = u;
    dfsParent[u] = father;
    dfsEdge[u] = head[u];
    enterNode(u);
    while (top > 0) {
        int x = dfsStack[top - 1];
        int e = dfsEdge[x];
        if (e == -1) {  // every child of x has been processed: leave x
            Remove(root, c[x]);
            --top;
            continue;
        }
        dfsEdge[x] = edge[e].next;
        int v = edge[e].to;
        if (v == dfsParent[x]) continue;
        dfsParent[v] = x;
        dfsEdge[v] = head[v];
        dfsStack[top++] = v;
        enterNode(v);
    }
}

int main() {
    int n, m;
    // `== 2` rather than `~scanf`: a return of 0 (bad input) must also stop the loop.
    while (scanf("%d%d", &n, &m) == 2) {
        init();
        for (int i = 1; i <= n; ++i) que[i].clear();
        for (int i = 1; i <= n; ++i) scanf("%d", &c[i]);
        for (int i = 0; i < n - 1; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            addedge(u, v);
            addedge(v, u);
        }
        BFS(1);
        for (int i = 0; i < m; ++i) {
            int s, t, a, b;
            scanf("%d%d%d%d", &s, &t, &a, &b);
            int l = LCA(s, t);
            que[l].push_back(Item(a, b, -2, i));
            que[s].push_back(Item(a, b, 1, i));
            que[t].push_back(Item(a, b, 1, i));
        }
        memset(ans, 0, sizeof(ans));
        root = NULL;
        dfs(1, -1);
        del(root);  // every inserted value was removed, so this is normally a no-op
        // Stream output instead of "%I64d", which is MSVC-only and misparsed by glibc.
        for (int i = 0; i < m; ++i) cout << ans[i] << (i == m - 1 ? '\n' : ' ');
    }
    return 0;
}

/* Sample input:
10 10
241 3873 7875 8445 7001 3861 245 1641 3277 2790
2 1
3 2
4 1
5 3
6 5
7 3
8 7
9 5
10 6
9 5 1213 8766
1 5 1776 9931
5 2 5099 7343
4 1 969 6636
10 5 5361 5985
3 6 6963 8056
6 1 2159 8721
5 5 3843 6771
7 5 2009 7985
1 7 4797 4861
*/
时间: 2024-10-26 19:50:40