看了题解
题目大意
略
题解
考虑一下:对于一个集合 \(S\),如何计算 \(S\) 中的人必须在 \(1\) 号之后被杀(不在集合内的也可能在 \(1\) 之后被杀)的概率 \(P(S)\)。
令
\[
A = \sum_{i} w_i \B = \sum_{i \in S} w_i
\]
那么
\[
P(S) = \sum_{i=0}^{\infty} {(\frac{A-B-w_1}{A})^i} \cdot \frac{w_1}{A} \= w_1 \cdot \frac{1}{B+w_1}
\]
容斥一下,要求没有人在 \(1\) 之后被杀的概率,答案为
\[
w_1 \cdot \sum_{S} (-1)^{|S|} \frac{1}{B + w_1}
\]
这个可以 DP 计算,令 \(dp[i][j]\) 为考虑了前 \(i\) 个人,集合的和为 \(j\) 的方案数,转移有两种:
\[
dp[i][j] \rightarrow dp[i+1][j] \-dp[i][j] \rightarrow dp[i+1][j+w[i]]
\]
这个可以看成一堆形如 \((1 - x^{w_i})\) 的多项式连乘,开个堆每次拿次数小的一个一个合并就好了,复杂度好像是 \(O(N \lg N)\)?
实现
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
const int MD = 998244353;
const int G = 3;
int pow_mod(int x, int n) {
int r = 1;
while (n) {
if (n & 1) r = ll(r) * x % MD;
x = ll(x) * x % MD;
n >>= 1;
}
return r;
}
namespace ntt {
const int S = 17;
const int NN = 1 << S;
static int rev[NN];
void init() {
for (int i = 1; i < NN; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (S - 1));
}
}
void fft(vector<int> &a) {
int N = int(a.size());
if (N == 1) return;
int s = S - __builtin_ctz(N);
for (int i = 0; i < N; i++) {
int r = rev[i] >> s;
if (i < r) swap(a[i], a[r]);
}
vector<int> root(N / 2);
root[0] = 1;
for (int k = 1; k < N; k *= 2) {
root[1] = pow_mod(G, (MD - 1) / (k << 1));
for (int i = 2; i < k; i++) {
root[i] = ll(root[i - 1]) * root[1] % MD;
}
for (int i = 0; i < N; i += 2 * k) {
for (int j = 0; j < k; j++) {
int d = ll(root[j]) * a[i + j + k] % MD;
a[i + j + k] = a[i + j] + MD - d;
if (a[i + j + k] >= MD) a[i + j + k] -= MD;
a[i + j] = a[i + j] + d - MD;
if (a[i + j] < 0) a[i + j] += MD;
}
}
}
}
vector<int> pmul(const vector<int> &a, const vector<int> &b) {
int A = int(a.size()), B = int(b.size());
int N = 1;
while (N < A + B - 1) N *= 2;
vector<int> ac(N), bc(N);
for (int i = 0; i < A; i++) ac[i] = a[i];
for (int i = 0; i < B; i++) bc[i] = b[i];
fft(ac);
fft(bc);
int ni = pow_mod(N, MD - 2);
for (int i = 0; i < N; i++) {
ac[i] = ll(ac[i]) * bc[i] % MD * ni % MD;
}
reverse(ac.begin() + 1, ac.end());
fft(ac);
ac.resize(A + B - 1);
return ac;
}
}
using ntt::pmul;
int n;
vector< vector<int> > p;
vector<int> w;
int main() {
cin.tie(0);
ios::sync_with_stdio(false);
ntt::init();
cin >> n;
p = vector< vector<int> >(n);
w = vector<int>(n);
int sw = 0;
for (int i = 0; i < n; i++) {
cin >> w[i];
sw += w[i];
}
priority_queue< P, vector<P>, greater<P> > que;
for (int i = 1; i < n; i++) {
p[i] = vector<int>(w[i] + 1);
p[i][0] = 1;
p[i][w[i]] = MD - 1;
que.push(P(w[i] + 1, i));
}
while (int(que.size()) > 1) {
int i = que.top().second; que.pop();
int j = que.top().second; que.pop();
p[i] = pmul(p[i], p[j]);
que.push(P(int(p[i].size()), i));
vector<int>().swap(p[j]);
}
vector<int> f = p[que.top().second];
int ans = 0;
for (int i = 0; i < int(f.size()); i++) {
int inv = pow_mod(i + w[0], MD - 2);
ans += ll(inv) * f[i] % MD;
if (ans >= MD) ans -= MD;
}
cout << ll(ans) * w[0] % MD << ‘\n‘;
return 0;
}
原文地址:https://www.cnblogs.com/hfccccccccccccc/p/10160918.html
时间: 2024-10-08 14:24:05