题外话
最近课程不是很紧,准备按AC率版切bz,争取一天一道题以上。然后我喜闻乐见的发现之前剩下的题基本都是数据结构>_<。蛋疼啊。。。
Description
给定一棵树,每个节点要么是黑色,要么是白色,能执行两个操作:把某一个点取反色,返回距离最远的黑色点对。
Solution
这题看起来链分治,边分治都可做,然后搜到了小岛的题解。发现了逼格更高的做法,看了曹钦翔的《数据结构的提炼与压缩》,跪烂了。。。
这题用到了dfs序的性质,也就是括号序列。。
定义一种对一棵树的括号编码。这种编码方式很直观,所以,这里不给出严格的定义,用以下这棵树为例:
(图片自行脑补。。。)
可以先序遍历后写成:[A[B[E][F[H][I]]][C][D[G]]]
去掉字母后的串:[[[][[][]]][][[]]]
就称为这棵树的括号编码。(这个编码本质上是由深度优先遍历得到的)
考察两个结点,如 E 和 G ,
取出介于它们之间的那段括号编码 :][[][]]][][[
把匹配的括号去掉,得到:]][[
我们看到 2 个 ] 和 2 个 [,也就是说,在树中,从 E 向上爬 2 步,再向下走 2 步就到了 G。
注意到,题目中需要的信息只有这棵树中点与点的距离,所以,贮存编码中匹配的括号是没有意义的。
因此,对于介于两个节点间的一段括号编码 S,可以用一个二元组 (a, b) 描述它,即这段编码去掉匹配括号后有 a 个 ] 和 b 个 [。
所以,对于两个点 PQ,如果介于某两点 PQ 之间编码 S 可表示为 (a, b),PQ 之间的距离就是 a+b。
也就是说,题目只需要动态维护:max{a+b | S’(a, b) 是 S 的一个子串,且 S’ 介于两个黑点之间},
这里 S 是整棵树的括号编码。我们把这个量记为 dis(s)。
现在,如果可以通过左边一半的统计信息和右边一半的统计信息,得到整段编码的统计,这道题就可以用熟悉的线段树解决了。
这需要下面的分析。
考虑对于两段括号编码 S1(a1, b1) 和 S2(a2, b2),如果它们连接起来形成 S(a, b)。
注意到 S1、S2 相连时又形成了 min{b, c} 对成对的括号,合并后它们会被抵消掉。(?..这里 b, c 应该分别是指 b1 和 a2。。。
所以:
当 a2 < b1 时第一段 [ 就被消完了,两段 ] 连在一起,例如:
] ] [ [ + ] ] ] [ [ = ] ] ] [ [
当 a2 >= b1 时第二段 ] 就被消完了,两段 [ 连在一起,例如:
] ] [ [ [ + ] ] [ [ = ] ] [ [ [ (?..反了?。。。
这样,就得到了一个十分有用的结论:
当 a2 < b1 时,(a,b) = (a1-b1+a2, b2),
当 a2 >= b1 时,(a,b) = (a1, b1-a2+b2)。
由此,又得到几个简单的推论:
(i) a+b = a1+b2+|a2-b1| = max{(a1-b1)+(a2+b2), (a1+b1)+(b2-a2)}
(ii) a-b = a1-b1+a2-b2
(iii) b-a = b2-a2+b1-a1
由 (i) 式,可以发现,要维护 dis(s),就必须对子串维护以下四个量:
right_plus:max{a+b | S’(a,b) 是 S 的一个后缀,且 S’ 紧接在一个黑点之后}
right_minus:max{a-b | S’(a,b) 是 S 的一个后缀,且 S’ 紧接在一个黑点之后}
left_plus:max{a+b | S’(a,b) 是 S 的一个前缀,且有一个黑点紧接在 S 之后}
left_minus:max{b-a | S’(a,b) 是 S 的一个前缀,且有一个黑点紧接在 S 之后}
这样,对于 S = S1 + S2,其中 S1(a, b)、S2(c, d)、S(e, f),就有
(e, f) = b < c ? (a-b+c, d) : (a, b-c+d)
dis(S) = max{dis(S1), left_minus(S2)+right_plus(S1), left_plus(S2)+right_minus(S1), dis(S2)}
那么,增加这四个参数是否就够了呢?
是的,因为:
right_plus(S) = max{right_plus(S1)-c+d, right_minus(S1)+c+d, right_plus(S2)}
right_minus(S) = max{right_minus(S1)+c-d, right_minus(S2)}
left_plus(S) = max{left_plus(S2)-b+a, left_minus(S2)+b+a, left_plus(S1)}
left_minus(S) = max{left_minus(S2)+b-a, left_minus(S1)}
这样一来,就可以用线段树处理编码串了。实际实现的时候,在编码串中加进结点标号会更方便,对于底层结点,如果对应字符是一个括号或者一个白点,那 么right_plus、right_minus、left_plus、left_minus、dis 的值就都是 -maxlongint;如果对应字符是一个黑点,那么 right_plus、right_minus、left_plus、left_minus 都是 0,dis 是 -maxlongint。
现在这个题得到圆满解决,回顾这个过程,可以发现用一个串表达整棵树的信息是关键,这一“压”使得线段树这一强大工具得以利用…
然后这个题就做完了,太神了。。。
Code
#include <bits/stdc++.h>
#define ls (rt << 1)
#define rs (rt << 1 | 1)
using namespace std;
const int N = 100010, inf = 1e9;
int n, q, tot, cnt, dfn[N * 3], to[N << 1], nxt[N << 1], head[N], pos[N], cc[N];
struct Node {
int l, r, l1, r1, l2, r2, c1, c2, dis;
void init(int x) {
dis = -inf;
c1 = c2 = 0;
if (dfn[x] == -2) c2 = 1;//(
if (dfn[x] == -5) c1 = 1;//)
if (dfn[x] > 0 && !cc[dfn[x]]) l1 = r1 = l2 = r2 = 0;
else l1 = r1 = l2 = r2 = -inf;
}
}a[N * 12];
inline int read(int &t) {
int f = 1;char c;
while (c = getchar(), c < ‘0‘ || c > ‘9‘) if (c == ‘-‘) f = -1;
t = c - ‘0‘;
while (c = getchar(), c >= ‘0‘ && c <= ‘9‘) t = t * 10 + c - ‘0‘;
t *= f;
}
void add(int u, int v) {
to[tot] = v, nxt[tot] = head[u], head[u] = tot++;
to[tot] = u, nxt[tot] = head[v], head[v] = tot++;
}
void dfs(int u, int fa) {
dfn[++cnt] = -2;//
dfn[++cnt] = u;
pos[u] = cnt;
for (int i = head[u], v; ~i; i = nxt[i]) {
v = to[i];
if (v != fa) dfs(v, u);
}
dfn[++cnt] = -5;//)
}
Node merge(Node a, Node b) {
Node c;
c.l = a.l, c.r = b.r;
c.dis = max(a.dis, b.dis);
c.dis = max(c.dis, max(a.r2 + b.l1, a.r1 + b.l2));
if (b.c1 < a.c2) c.c1 = a.c1, c.c2 = a.c2 - b.c1 + b.c2;
else c.c1 = a.c1 + b.c1 - a.c2, c.c2 = b.c2;
c.l1 = max(a.l1, max(b.l1 + a.c1 - a.c2, b.l2 + a.c1 + a.c2));
c.r1 = max(b.r1, max(a.r1 - b.c1 + b.c2, a.r2 + b.c1 + b.c2));
c.l2 = max(a.l2, b.l2 + a.c2 - a.c1);
c.r2 = max(b.r2, a.r2 + b.c1 - b.c2);
return c;
}
void change(int rt, int p) {
if (a[rt].l == a[rt].r) {
a[rt].init(a[rt].l);
return;
}
int mid = a[rt].l + a[rt].r >> 1;
if (p <= mid) change(ls, p);
else change(rs, p);
a[rt] = merge(a[ls], a[rs]);
}
void build(int rt, int l, int r) {
if (l == r) {
a[rt].l = l, a[rt].r = r;
a[rt].init(l);
return;
}
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
a[rt] = merge(a[ls], a[rs]);
}
int main() {
memset(head, -1, sizeof(head));
read(n);
int now = n;
for (int i = 1, x, y; i < n; ++i) {
read(x), read(y);
add(x, y);
}
dfs(1, 0);
build(1, 1, cnt);
read(q);
while (q--) {
char s[10];
int x;
scanf("%s", s);
if (s[0] == ‘C‘) {
read(x);
if (!cc[x]) --now;
else ++now;
cc[x] ^= 1;
change(1, pos[x]);
}
else {
if (!now) puts("-1");
else if (now == 1) puts("0");
else printf("%d\n", a[1].dis);
}
}
return 0;
}