[ABC215G]Colorful Candies 2 题解

期望

Statement

G - Colorful Candies 2 (atcoder.jp)

给定 \(n\) 个糖果,第 \(i\) 个糖果颜色为 \(c_i\)

对于每个 \(k=1∼n\),求随机选出 \(k\) 个糖果,\(\binom nk\) 种情况中糖果颜色数的期望。答案模 \(998244353\)。

\(n\le 5\times 10^4\)

Solution

知道这类问题的一般套路都是先利用期望的线性性转化问题求解

但是不会用期望的线性性,于是最开始设了一个 \(dp[i][j][k]\) 表示前 \(i\) 种颜色中选择 \(j\) 个,选出 \(k\) 不同颜色的方案数,然后前缀和优化之后得到一个 \(O(n^3)\) 后发现状态实在减不动了,GG

回到正题,注意到本题中对答案有影响的是不同颜色的数量和每种颜色数量,并不关注每种颜色具体是什么

枚举 \(k\) ,题目要求计算 \(E(\sum x_i)\) ,\(x_i=1/0\) 表示选择 \(k\) 个数,颜色 \(i\) 是/否被选中

有 \(E(\sum x_i)=\sum E(x_i)\) ,考虑求出选择 \(k\) 个数中有 \(i\) 的期望

简单容斥一下,用总方案数-选不中方案数,设颜色 \(i\) 有 \(cnt_i\) 个

\[E(x_i)=\dfrac{\binom nk-\binom {n-cnt_i}k}{\binom nk} \]

暴力计算是 \(O(n^2)\) 的,考虑优化

注意到其实 \(cnt\) 相同的可以和在一起算,又 \(\sum cnt_i=n\) ,所以不同的 \(cnt\) 只有 \(\sqrt n\) 种,复杂度来到 \(O(n\sqrt n)\)

Code

#include<bits/stdc++.h> using namespace std; const int N = 5e4+5; const int mod = 998244353;  char buf[1<<23],*p1=buf,*p2=buf; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) int read(){     int s=0,w=1; char ch=getchar();     while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}     while(isdigit(ch))s=s*10+(ch^48),ch=getchar();     return s*w; } int ksm(int a,int b){     int res=1;     while(b){         if(b&1)res=1ll*res*a%mod;         a=1ll*a*a%mod,b>>=1;     }     return res; } void inc(int &a,int b){a=a>=mod-b?a-mod+b:a+b;}  vector<int>vec; int jc[N],invj[N],a[N],b[N]; int tong[N],cnt[N]; bool vis[N]; int n,m;  int C(int n,int m){     if(n<m||m<0)return 0;     return 1ll*jc[n]*invj[m]%mod*invj[n-m]%mod; }  signed main(){     n=read();     for(int i=jc[0]=1;i<=n;++i)         jc[i]=1ll*jc[i-1]*i%mod,a[i]=b[i]=read();     invj[n]=ksm(jc[n],mod-2);     for(int i=n-1;~i;--i)invj[i]=1ll*invj[i+1]*(i+1)%mod;     sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;     for(int i=1;i<=n;++i)         tong[lower_bound(b+1,b+1+m,a[i])-b]++;     for(int i=1;i<=m;++i){         cnt[tong[i]]++;         if(vis[tong[i]]==0)             vec.push_back(tong[i]),vis[tong[i]]=1;     }     for(int k=1;k<=n;++k){         int ans=0,all=C(n,k),invl=ksm(all,mod-2);         for(auto v:vec)             inc(ans,1ll*cnt[v]*((all-C(n-v,k)+mod)%mod)%mod*invl%mod);         printf(%d\n,ans);     }     return 0; }