「ZJOI2019」Minimax 搜索(动态dp)

Address

loj3044

Solution

考虑对 \(k\in [l-1,r]\) 分别求出有多少个集合 \(S\) 满足 \(w(S)\le k\),记作 \(ans_k\)。

先求出 \(1\) 的初始权值 \(W\)。

记 \(val(x)\) 表示 \(x\) 的权值。枚举 \(k\),现在对于每个叶子 \(u\),如果 \(u\in S\),那么 \(val(u)\in [u-W,u+W]\),否则 \(val(u)=W\)。

我们发现,把叶子节点的权值改成 \(W\) 肯定是不优的。所以改动一些叶子后,如果 \(val(1)\) 还是 \(W\),那么肯定路径 \(1→W\) 上每个点的权值都是 \(W\),且其它的点的权值都不是 \(W\)。

因此,如果想要 \(val(1)\) 改变,那么路径 \(1→W\) 上肯定存在一个点 \(x\),\(val(x)\ne W\)。记 \(x\) 在路径 \(1→W\) 上的子节点为 \(y\)。如果 \(x\) 深度是奇数, 那么肯定存在一个 \(x\) 的子节点 \(z(z\ne y)\),\(val(z)>W\)。\(x\) 深度是偶数时同理。

我们把 \(1→W\) 上的边全部断掉,再求一遍每个点的权值。如果原路径 \(1→W\) 上存在某个深度为奇数的点的权值 \(>W\),或者某个深度为偶数的点的权值 \(<W\),那么 \(val(1)\) 肯定改变,否则肯定不变。

记 \(f(u)\) 表示 \(u\) 子树中,使 \(val(u)>w\) 的合法叶子节点集合有几个。\(g(u)\) 表示 \(u\) 子树中,使 \(val(u)<w\) 的合法叶子节点集合有几个。

如果 \(u\) 是叶子节点:\(f(u)=[u>W]+[u+k>W],g(u)=[u<W]+[u-k<W]\)。其中 \([u>W],[u<W]\) 表示 \(u\) 不在叶子节点集合内,\([u+k>W],[u-k<W]\) 表示在集合内。

如果 \(u\) 是深度为奇数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点最大权值必须 \(>W\),也就是说不能全部 \(\le W\)。因此 \(f(u)=2^{cnt_u}\prod_{v\in son_u}(2^{cnt_v}-f(v))\)。其中 \(cnt_u\) 表示 \(u\) 的子树内有几个叶子节点。

如果 \(u\) 是深度为偶数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点全部 \(<W\)。因此 \(f(u)=\prod_{v\in son_u}f(v)\)。

\(g\) 的转移和 \(f\) 类似。

接下来求 \(ans_k\)。考虑补集转化,即用 \(2^{cnt_1}\) 减去不会让 \(val(1)\) 改变的集合数。不会让 \(val(1)\) 改变,就是要让原路径 \(1→W\) 上的每个点的权值都不变。那么把深度为奇数的 \(2^{cnt_x}-f_x\) 和深度为偶数的 \(2^{cnt_x}-g_x\) 全部相乘就是答案了。

至此,我们得到了一个 \(O(n(R-L))\) 的做法。

考虑优化,我们发现转移与 \(k\) 无关,只有叶子节点的 \(f,g\) 和 \(k\) 有关。进一步地,我们发现随着 \(k\) 变大,每个叶子节点的 \(f,g\) 最多改变一次。因此可以看作是 \(O(n)\) 次修改的动态 \(dp\),时间复杂度 \(O(n\log^2 n)\)。

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long
#define p2 p << 1
#define p3 p << 1 | 1

template <class t>
inline void read(t & res)
{
    char ch;
    while (ch = getchar(), !isdigit(ch));
    res = ch ^ 48;
    while (ch = getchar(), isdigit(ch))
    res = res * 10 + (ch ^ 48);
}

template <class t>
inline void print(t x)
{
    if (x > 9) print(x / 10);
    putchar(x % 10 + 48);
}

const int e = 2e5 + 5, mod = 998244353;

struct point
{
    int x, y;
}b[e], que[e];
struct matrix
{
    int a, b;

    matrix(){}
    matrix(int _a, int _b) :
        a(_a), b(_b) {}
}tr[e << 2];
vector<int>g[e], c[e], d[e];
int f[e], dep[e], L, R, w, n, fa[e], a[e], m, nxt[e], go[e], adj[e], val[e], K, cnt[e], f2[e];
int q[e], h[e], num, all, sum[e << 2], son[e], sze[e], dfnA[e], dfnB[e], timA, timB, idA[e], idB[e];
int st[e], ed[e], bot[e], top[e], ans[e], rt[e], now_rt;
bool is[e], op, bo[e];

inline void add(int &x, int y)
{
    (x += y) >= mod && (x -= mod);
}

inline void del(int &x, int y)
{
    (x -= y) < 0 && (x += mod);
}

inline int plu(int x, int y)
{
    add(x, y);
    return x;
}

inline int sub(int x, int y)
{
    del(x, y);
    return x;
}

inline int mul(int x, int y)
{
    return (ll)x * y % mod;
}

inline int ksm(int x, int y)
{
    int res = 1;
    while (y)
    {
        if (y & 1) res = mul(res, x);
        y >>= 1;
        x = mul(x, x);
    }
    return res;
}

inline matrix operator + (matrix u, matrix v)
{
    return matrix(mul(u.a, v.a), plu(mul(u.b, v.a), v.b));
}

inline void link1(int x, int y)
{
    g[x].push_back(y);
    g[y].push_back(x);
}

inline void link2(int x, int y)
{
    nxt[++num] = adj[x]; adj[x] = num; go[num] = y;
}

inline void dfs1(int u, int pa)
{
    dep[u] = dep[pa] + 1;
    fa[u] = pa;
    if (dep[u] & 1) val[u] = 0;
    else val[u] = n + 1;
    int len = g[u].size(), i;
    bool pd = 0;
    for (i = 0; i < len; i++)
    {
        int v = g[u][i];
        if (v == pa) continue;
        pd = 1;
        dfs1(v, u);
        if (dep[u] & 1) val[u] = max(val[u], val[v]);
        else val[u] = min(val[u], val[v]);
    }
    if (!pd) val[u] = u, all++;
}

inline void dfs2(int u)
{
    if (val[u] == u)
    {
        if (op)
        {
            f[u] = (u > w) + (u + K > w);
            if (L <= w + 1 - u && w + 1 - u <= R) c[w + 1 - u].push_back(u);
        }
        else
        {
            f[u] = (u < w) + (u - K < w);
            if (L <= u + 1 - w && u + 1 - w <= R) d[u + 1 - w].push_back(u);
        }
        return;
    }
    f[u] = f2[u] = 1;
    bool fl = ((dep[u] & 1) && op) || ((~dep[u] & 1) && !op);
    bo[u] = fl;
    for (int i = adj[u]; i; i = nxt[i])
    {
        int v = go[i];
        dfs2(v);
        if (fl) f[u] = mul(f[u], sub(q[v], f[v]));
        else f[u] = mul(f[u], f[v]);
        if (v != son[u])
        {
            if (fl) f2[u] = mul(f2[u], sub(q[v], f[v]));
            else f2[u] = mul(f2[u], f[v]);
        }
    }
    if (fl) f[u] = sub(q[u], f[u]);
}

inline void dfs3(int u)
{
    if (val[u] == u) cnt[u] = 1;
    sze[u] = 1;
    rt[u] = now_rt;
    for (int i = adj[u]; i; i = nxt[i])
    {
        int v = go[i];
        dfs3(v);
        cnt[u] += cnt[v];
        sze[u] += sze[v];
        if (sze[v] > sze[son[u]]) son[u] = v;
    }
}

inline void dfs4(int u, int fi)
{
    top[u] = fi;
    dfnA[u] = ++timA;
    idA[timA] = u;
    if (son[u])
    {
        dfs4(son[u], fi);
        st[u] = timB + 1;
        for (int i = adj[u]; i; i = nxt[i])
        {
            int v = go[i];
            if (v == son[u]) continue;
            dfnB[v] = ++timB;
            idB[timB] = v;
        }
        ed[u] = timB;
    }
    for (int i = adj[u]; i; i = nxt[i])
    {
        int v = go[i];
        if (v == son[u]) continue;
        dfs4(v, v);
    }
    if (son[u]) bot[u] = bot[son[u]];
    else bot[u] = u;
}

inline void build(int l, int r, int p)
{
    if (l == r)
    {
        int u = idA[l], v = idB[l];
        if (son[u])
        {
            if (bo[u])
            {
                int v = son[u];
                tr[p] = matrix(f2[u], sub(q[u], mul(f2[u], q[v])));
            }
            else tr[p] = matrix(f2[u], 0);
        }
        if (v)
        {
            int pa = fa[v];
            if (bo[pa]) sum[p] = sub(q[v], f[v]);
            else sum[p] = f[v];
        }
        return;
    }
    int mid = l + r >> 1;
    build(l, mid, p2);
    build(mid + 1, r, p3);
    tr[p] = tr[p3] + tr[p2];
    sum[p] = mul(sum[p2], sum[p3]);
}

inline void upt_tr(int l, int r, int s, matrix u, int p)
{
    if (l == r)
    {
        tr[p] = u;
        return;
    }
    int mid = l + r >> 1;
    if (s <= mid) upt_tr(l, mid, s, u, p2);
    else upt_tr(mid + 1, r, s, u, p3);
    tr[p] = tr[p3] + tr[p2];
}

inline void upt_sum(int l, int r, int s, int v, int p)
{
    if (l == r)
    {
        sum[p] = v;
        return;
    }
    int mid = l + r >> 1;
    if (s <= mid) upt_sum(l, mid, s, v, p2);
    else upt_sum(mid + 1, r, s, v, p3);
    sum[p] = mul(sum[p2], sum[p3]);
}

inline matrix ask_tr(int l, int r, int s, int t, int p)
{
    if (l == s && r == t) return tr[p];
    int mid = l + r >> 1;
    if (t <= mid) return ask_tr(l, mid, s, t, p2);
    else if (s > mid) return ask_tr(mid + 1, r, s, t, p3);
    else return ask_tr(mid + 1, r, mid + 1, t, p3) + ask_tr(l, mid, s, mid, p2);
}

inline int ask_sum(int l, int r, int s, int t, int p)
{
    if (l == s && r == t) return sum[p];
    int mid = l + r >> 1;
    if (t <= mid) return ask_sum(l, mid, s, t, p2);
    else if (s > mid) return ask_sum(mid + 1, r, s, t, p3);
    else return mul(ask_sum(l, mid, s, mid, p2), ask_sum(mid + 1, r, mid + 1, t, p3));
}

inline void pair_mul(point &u, int x)
{
    if (!x) u.y++;
    else u.x = mul(u.x, x);
}

inline void pair_div(point &u, int x)
{
    if (!x) u.y--;
    else u.x = mul(u.x, ksm(x, mod - 2));
}

inline void cover(int &x, point u)
{
    int res = u.x;
    if (u.y) res = 0;
    x = sub(all, res);
}

inline int calc(int x, matrix c)
{
    return plu(mul(x, c.a), c.b);
}

inline int ask(int x)
{
    if (x == bot[x]) return f[x];
    int l = dfnA[x], r = dfnA[bot[x]] - 1;
    return calc(f[bot[x]], ask_tr(1, n, l, r, 1));
}

inline void change(int x)
{
    pair_div(que[K], sub(q[rt[x]], f[rt[x]]));
    x = top[x];
    while (x)
    {
        f[x] = ask(x);
        if (!fa[x]) break;
        int y = fa[x];
        if (bo[y]) upt_sum(1, n, dfnB[x], sub(q[x], f[x]), 1);
        else upt_sum(1, n, dfnB[x], f[x], 1);
        f2[y] = ask_sum(1, n, st[y], ed[y], 1);

        matrix tmp;
        if (bo[y])
        {
            int v = son[y];
            tmp = matrix(f2[y], sub(q[y], mul(f2[y], q[v])));
        }
        else tmp = matrix(f2[y], 0);
        upt_tr(1, n, dfnA[y], tmp, 1);

        x = top[y];
    }
    pair_mul(que[K], sub(q[x], f[x]));
}

int main()
{
    freopen("minimax.in", "r", stdin);
    freopen("minimax.out", "w", stdout);
    read(n); read(L); read(R);
    int i, x, y, j;
    for (i = 1; i < n; i++) read(x), read(y), link1(x, y), b[i].x = x, b[i].y = y;
    dfs1(1, 0);
    x = w = val[1];
    h[0] = 1;
    for (i = 1; i <= n; i++) h[i] = plu(h[i - 1], h[i - 1]);
    while (x != 1)
    {
        a[++m] = x;
        x = fa[x];
    }
    a[++m] = 1;
    reverse(a + 1, a + m + 1);
    for (i = 1; i <= m; i++) is[a[i]] = 1;
    for (i = 1; i < n; i++)
    {
        x = b[i].x; y = b[i].y;
        if (!is[x] || !is[y])
        {
            if (fa[x] == y) link2(y, x);
            else link2(x, y);
        }
    }
    for (i = 1; i <= m; i++) now_rt = a[i], dfs3(a[i]);
    all = h[all];
    for (i = 1; i <= n; i++) q[i] = h[cnt[i]];
    for (i = 1; i <= m; i++) dfs4(a[i], a[i]), fa[a[i]] = 0;

    bool flag = 0;
    if (L == 1) K = L, flag = 1, L++;
    else K = L - 1;
    que[K].x = 1;
    for (j = 1; j <= m; j++)
    {
        int u = a[j];
        op = j & 1; dfs2(u);
        pair_mul(que[K], sub(q[u], f[u]));
    }
    cover(ans[K], que[K]);
    build(1, n, 1);
    for (i = L; i <= R; i++)
    {
        que[i] = que[i - 1];
        K = i;
        int lenc = c[i].size(), lend = d[i].size();
        for (j = 0; j < lenc; j++)
        {
            int u = c[i][j];
            f[u] = (u > w) + (u + K > w);
            change(u);
        }
        for (j = 0; j < lend; j++)
        {
            int u = d[i][j];
            f[u] = (u < w) + (u - K < w);
            change(u);
        }
        if (i == n)
        {
            ans[i] = sub(all, 1);
            continue;
        }
        cover(ans[i], que[i]);
    }
    if (flag) L--;
    for (i = L; i <= R; i++)
    print(sub(ans[i], ans[i - 1])), putchar(i == R ? '\n' : ' ');
    return 0;
}

原文地址:https://www.cnblogs.com/cyf32768/p/12296954.html

时间: 2024-11-08 23:54:45

「ZJOI2019」Minimax 搜索(动态dp)的相关文章

loj2537 「PKUWC2018」Minimax 【概率 + 线段树合并】

题目链接 loj2537 题解 观察题目的式子似乎没有什么意义,我们考虑计算出每一种权值的概率 先离散化一下权值 显然可以设一个\(dp\),设\(f[i][j]\)表示\(i\)节点权值为\(j\)的概率 如果\(i\)是叶节点显然 如果\(i\)只有一个儿子直接继承即可 如果\(i\)有两个儿子,对于儿子\(x\),设另一个儿子为\(y\) 则有 \[f[i][j] += f[x][j](1 - p_i)\sum\limits_{k > j}f[r][k] + f[x][j]p_i\sum\

loj2537. 「PKUWC2018」Minimax

题意 略. 题解 首先设\(f_{x, c}\)表示以\(x\)为根的子树内,最终取到了\(c\)的概率.可以列出转移方程(假设有两个孩子\(u, v\)) \[ \begin{aligned} f_{x, c} = & f_{u, c} * (p * v子树中最终权值小于c的概率 + (1 - p) * v子树中最终权值大于c的概率) \+ & f_{v, c} * (p * u子树中最终权值小于c的概率 + (1 - p) * u子树中最终权值大于c的概率) \\end{aligned

LG2145 「JSOI2007」祖码 区间DP

问题描述 LG2145 题解 把颜色相同的一段看做一个点. 然后类似于合唱队区间DP即可. 但是这题好像出过一些情况,导致我包括题解区所有人需要特判最后一个点. \(\mathrm{Code}\) #include<bits/stdc++.h> using namespace std; template <typename Tp> void read(Tp &x){ x=0;char ch=1;int fh; while(ch!='-'&&(ch<'0

LG4158 「SCOI2009」粉刷匠 线性DP

问题描述 LG4158 题解 设\(opt[i][j][k]\)代表到\((i,k)\)刷了\(j\)次的方案数. 一开始DP顺序有点问题,调了很长时间. 务必考虑清楚DP顺序问题 \(\mathrm{Code}\) #include<bits/stdc++.h> using namespace std; template <typename Tp> void read(Tp &x){ x=0;char ch=1;int fh; while(ch!='-'&&

loj#2537. 「PKUWC2018」Minimax

传送门 感觉我去pkuwc好像只有爆零的份-- 设\(f_{u,i}\)表示\(u\)取到\(i\)的概率,那么有如下转移 \[f_{u,i}=f_{ls,i}(p_u\sum_{j<i}f_{rs,j}+(1-p_u)\sum_{j>i}f_{rs,j})+\\f_{rs,i}(p_u\sum_{j<i}f_{ls,j}+(1-p_u)\sum_{j>i}f_{ls,j})\] 然后用线段树合并即可,最后在根节点的线段树上\(dfs\)统计答案 //minamoto #inclu

「PKUWC2018」Minimax

传送门 Solution 发现叶子节点的值都不样,所以可以线段树合并. 然后因为我们要维护一个后缀,所以我们先合并右儿子,在合并左儿子 Code? //2019.1.14 8:59~10:15 PaperCloud #include<bits/stdc++.h> #define ll long long #define max(a,b) ((a)>(b)?(a):(b)) #define min(a,b) ((a)<(b)?(a):(b)) inline int read() {

「ZJOI2019」线段树

传送门 Description 线段树的核心是懒标记,下面是一个带懒标记的线段树的伪代码,其中 tag 数组为懒标记: 其中函数\(Lson(Node)\)表示\(Node\)的左儿子,\(Rson(Node)\)表示\(Node\)的右儿子. 有一棵 \([1,n]\)上的线段树,编号为\(1\) .初始时什么标记都没有. 每次修改会把当前所有的线段树复制一份,然后对于这些线段树实行一次区间修改操作. 每次修改后线段树棵数翻倍,第 \(i\)次修改后,线段树共有 \(2^i\) 棵. 每次询问

「校内训练 2019-04-23」越野赛车问题 动态dp+树的直径

题目传送门 http://192.168.21.187/problem/1236 http://47.100.137.146/problem/1236 题解 题目中要求的显然是那个状态下的直径嘛. 所以这道题有一个非常简单的做法--线段树分治. 直接把每一条边按照 \(l, r\) 的区间放到线段树上进行分治,遍历的时候用并查集维护直径就可以了. 时间复杂度为 \(O(n\log^2n)\). 很早以前就写了这个算法,代码附在了最后,不多讲了. 但是这道题还有一个方法--动态 DP. 线段树分治

「CQOI2011」动态逆序对

「CQOI2011」动态逆序对 传送门 树套树. 删除一个位置的元素带来的减损数等于他前面大于它的和后面小于它的,然后这个直接树状数组套主席树维护一下就好了. 参考代码: #include <cstdio> #define rg register #define file(x) freopen(x".in", "r", stdin), freopen(x".out", "w", stdout) template &