P5644 [PKUWC2018]猎人杀
题目大意
一开始有
n
n
n个猎人,第
i
i
i个猎人有仇恨度
w
i
w_i
wi。每次可以开枪射杀一个活着的猎人。
假设活着的猎人为
i
1
,
i
2
,
…
,
i
m
i_1,i_2,\dots,i_m
i1,i2,…,im,则第
i
k
i_k
ik个猎人被射杀的概率是
w
i
k
∑
j
=
1
m
w
i
j
\frac{w_{i_k}}{\sum\limits_{j=1}^mw_{i_j}}
j=1∑mwijwik。
求
1
1
1号猎人最后一个被射杀的概率。输出答案对
998244353
998244353
998244353取模。
w
i
>
0
,
1
≤
∑
w
i
≤
1
0
5
w_i>0,1\leq \sum w_i\leq 10^5
wi>0,1≤∑wi≤105
题解
转化题意
在不断射杀的过程中,概率的分母会不断改变。所以,我们可以稍微转化一下题意。
令
s
u
m
=
∑
i
=
1
n
w
i
sum=\sum\limits_{i=1}^nw_i
sum=i=1∑nwi,题意转化为:第
i
i
i个人被射杀的概率为
w
i
s
u
m
\dfrac{w_i}{sum}
sumwi,已经被射杀的人仍能继续被射杀。如果打中一个活着的人,那么这个人就死去。
为什么呢?因为在每次射杀的时候,每个活人被射杀的概率都他的仇恨度除以所有活人仇恨度的和。这样相对于原来,多了一个射杀死人的过程,但因为活人最终总会被射杀,所以本质上是一样的。
容斥
直接求
1
1
1号最后被射杀的概率比较困难,我们考虑使用容斥。
设在
1
1
1号之后被射杀的人的子集为
S
S
S的概率为
p
(
S
)
p(S)
p(S),则答案为
a
n
s
=
∑
(
−
1
)
∣
S
∣
p
(
S
)
ans=\sum(-1)^{|S|}p(S)
ans=∑(−1)∣S∣p(S)
然后,我们考虑如何求
p
(
S
)
p(S)
p(S)。
设
v
(
S
)
=
∑
i
∈
S
w
i
v(S)=\sum\limits_{i\in S}w_i
v(S)=i∈S∑wi。因为在
1
1
1号之后被射杀的人包含
S
S
S,所以就相当于射杀若干次,每次射杀除了
1
1
1号和集合
S
S
S之外的人,直到打中
1
1
1号。
p
(
S
)
=
∑
i
=
0
+
∞
(
s
u
m
−
w
1
−
v
(
S
)
s
u
m
)
i
⋅
w
1
s
u
m
p(S)=\sum\limits_{i=0}^{+\infty}(\dfrac{sum-w_1-v(S)}{sum})^i\cdot\dfrac{w_1}{sum}
p(S)=i=0∑+∞(sumsum−w1−v(S))i⋅sumw1
接下来求
∑
i
=
0
+
∞
(
s
u
m
−
w
1
−
v
(
S
)
s
u
m
)
i
\sum\limits_{i=0}^{+\infty}(\frac{sum-w_1-v(S)}{sum})^i
i=0∑+∞(sumsum−w1−v(S))i。因为
s
u
m
−
w
1
−
v
(
S
)
s
u
m
\frac{sum-w_1-v(S)}{sum}
sumsum−w1−v(S)在
[
0
,
1
)
[0,1)
[0,1)上,所以其无限次方为
0
0
0。由等比数列求和公式可得
∑
i
=
0
+
∞
(
s
u
m
−
w
1
−
v
(
S
)
s
u
m
)
i
=
1
−
(
s
u
m
−
w
1
−
v
(
S
)
s
u
m
)
+
∞
1
−
s
u
m
−
w
1
−
v
(
S
)
s
u
m
=
1
w
1
+
v
(
S
)
s
u
m
=
s
u
m
w
1
+
v
(
S
)
\sum\limits_{i=0}^{+\infty}(\frac{sum-w_1-v(S)}{sum})^i=\dfrac{1-(\frac{sum-w_1-v(S)}{sum})^{+\infty}}{1-\frac{sum-w_1-v(S)}{sum}}=\dfrac{1}{\frac{w_1+v(S)}{sum}}=\dfrac{sum}{w_1+v(S)}
i=0∑+∞(sumsum−w1−v(S))i=1−sumsum−w1−v(S)1−(sumsum−w1−v(S))+∞=sumw1+v(S)1=w1+v(S)sum
所以
p
(
S
)
=
s
u
m
w
1
+
v
(
S
)
⋅
w
1
s
u
m
=
w
1
w
1
+
v
(
S
)
p(S)=\dfrac{sum}{w_1+v(S)}\cdot \dfrac{w_1}{sum}=\dfrac{w_1}{w_1+v(S)}
p(S)=w1+v(S)sum⋅sumw1=w1+v(S)w1
那么
a
n
s
=
∑
(
−
1
)
∣
S
∣
w
1
w
1
+
v
(
S
)
ans=\sum(-1)^{|S|}\dfrac{w_1}{w_1+v(S)}
ans=∑(−1)∣S∣w1+v(S)w1
生成函数
依题意,
∑
w
i
≤
1
0
5
\sum w_i\leq 10^5
∑wi≤105,所以我们可以枚举
v
(
S
)
v(S)
v(S)
令
g
(
i
)
=
∑
v
(
S
)
=
i
(
−
1
)
∣
S
∣
g(i)=\sum\limits_{v(S)=i}(-1)^{|S|}
g(i)=v(S)=i∑(−1)∣S∣
那么
a
n
s
=
∑
i
=
0
s
u
m
g
(
i
)
⋅
w
1
w
1
+
v
(
S
)
ans=\sum\limits_{i=0}^{sum}g(i)\cdot\dfrac{w_1}{w_1+v(S)}
ans=i=0∑sumg(i)⋅w1+v(S)w1
那问题就转化为求
g
(
i
)
g(i)
g(i)了。用生成函数,
g
(
i
)
g(i)
g(i)其实就是
∏
i
=
2
n
(
1
−
x
w
i
)
\prod\limits_{i=2}^n(1-x^{w_i})
i=2∏n(1−xwi)的第
i
i
i次项的系数。
N
T
T
NTT
NTT
接下来,我们要用
N
T
T
NTT
NTT来求
∏
i
=
2
n
(
1
−
x
w
i
)
\prod\limits_{i=2}^n(1-x^{w_i})
i=2∏n(1−xwi)。
用分治,求
[
l
,
r
]
[l,r]
[l,r]的多项式时,先求出
[
l
,
m
i
d
]
[l,mid]
[l,mid]和
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r]的多项式,在将两个多项式相乘,即可求出。
我们把求的过程看作
log
n
\log n
logn层,每层的时间复杂度为
O
(
s
u
m
log
s
u
m
)
O(sum\log sum)
O(sumlogsum),所以总时间复杂度为
O
(
s
u
m
log
s
u
m
log
n
)
O(sum\log sum\log n)
O(sumlogsumlogn)。
总时间复杂度可以看作
O
(
n
log
2
n
)
O(n\log^2 n)
O(nlog2n)。
code
#include<bits/stdc++.h>
using namespace std;
long long ans=0,w[100005],f[500005],g[20][500005];
const long long G=3,mod=998244353;
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void ch(long long *a1,int l){
for(int i=1,j=l/2;i<l-1;i++){
if(i<j) swap(a1[i],a1[j]);
int k=l/2;
while(j>=k){
j-=k;k>>=1;
}
j+=k;
}
}
void ntt(long long *a1,int l,int fl){
long long W,wn;
for(int i=2;i<=l;i<<=1){
if(fl==1) wn=mi(G,(mod-1)/i);
else wn=mi(G,mod-1-(mod-1)/i);
for(int j=0;j<l;j+=i){
W=1;
for(int k=j;k<j+i/2;k++,W=W*wn%mod){
long long t=a1[k],u=W*a1[k+i/2]%mod;
a1[k]=(t+u)%mod;
a1[k+i/2]=(t-u+mod)%mod;
}
}
}
if(fl==-1){
long long ny=mi(l,mod-2);
for(int i=0;i<l;i++) a1[i]=a1[i]*ny%mod;
}
}
int solve(int l,int r,long long *a1,int now){
if(l==r){
a1[0]=1;a1[w[l]]=mod-1;
return w[l];
}
int mid=l+r>>1,vt,len=1;
vt=solve(l,mid,a1,now+1)+solve(mid+1,r,g[now+1],now+1);
while(len<=vt) len<<=1;
ch(a1,len);ch(g[now+1],len);
ntt(a1,len,1);ntt(g[now+1],len,1);
for(int i=0;i<len;i++){
a1[i]=a1[i]*g[now+1][i]%mod;
g[now+1][i]=0;
}
ch(a1,len);
ntt(a1,len,-1);
return vt;
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&w[i]);
}
if(n==1){
printf("1");return 0;
}
int vt=solve(2,n,f,0);
for(int i=0;i<=vt;i++){
ans=(ans+f[i]*w[1]%mod*mi(w[1]+i,mod-2)%mod)%mod;
}
printf("%lld",ans);
return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)