很好的题,可以用线段树做,也可以递推(dp?)
Mex
题意:求所有区间的mex和。mex值为没有在该区间出现过的最小非负整数。
先用比较好理解的线段树:
蒟蒻参考了多位dalao的题解,这里就不放链接了。。。
想法是每次求以i为起点的区间的mex值的和,最后累加即为答案。
可以先预处理出1为起点的区间的mex值,用它构造一棵线段树,要有求和和求最大值的操作。
为什么要求最大值呢? 往下看~
构造好线段树之后,1为起点的区间的mex和已经得到了,就是sum[1]。
接下来求2为起点的,首先第一步要删掉a[1],这里把mex[1]置为0。
删掉val=a[1]这个元素对后面的有什么影响呢?易知它不会对下一个出现val(下一个出现val的位置是nex[1])之后的位置产生影响,因为后面的区间已经包含val了(不少它这一个)
然后我们需要求得大于val的最大mex的位置pos,线段树单点查询最大值即可得到。
对于pos到nex[1] - 1的元素,所有的mex值都要变成a[i](不难想),成段更新线段树即可。
这样就得到了2为起点的区间的mex和,sum[1]。
下面就是重复了~~
代码中有几处很巧妙的地方,要细细领会~
1 #include <bits/stdc++.h> 2 using namespace std; 3 const int maxn=200010; 4 #define CLR(m,a) memset(m,a,sizeof(m)) 5 #define ll long long 6 #define lson l,m,rt<<1 7 #define rson m+1,r,rt<<1|1 8 int maxmex[maxn<<2],_set[maxn<<2]; 9 ll sum[maxn<<2]; 10 11 int mex[maxn],vis[maxn]; 12 int head[maxn],nex[maxn]; 13 int a[maxn]; 14 15 void pushup(int rt){ 16 sum[rt]=sum[rt<<1]+sum[rt<<1|1]; 17 maxmex[rt]=max(maxmex[rt<<1],maxmex[rt<<1|1]); 18 } 19 void build(int l,int r,int rt){ 20 _set[rt]=0; 21 if(l==r){ 22 sum[rt]=maxmex[rt]=mex[l]; 23 return ; 24 } 25 int m=(l+r)>>1; 26 build(lson); 27 build(rson); 28 pushup(rt); 29 } 30 void pushdown(int rt,int len){ 31 if(_set[rt]){ 32 _set[rt<<1]=_set[rt<<1|1]=1; 33 sum[rt<<1]=maxmex[rt]*1ll*(len-(len>>1)); 34 sum[rt<<1|1]=maxmex[rt]*1ll*(len>>1); 35 maxmex[rt<<1]=maxmex[rt<<1|1]=maxmex[rt]; 36 _set[rt]=0; 37 } 38 } 39 void update(int L,int R,int v,int l,int r,int rt){ 40 if(L<=l&&r<=R){ 41 _set[rt]=1; 42 sum[rt]=1ll*(r-l+1)*v; 43 maxmex[rt]=v; 44 return ; 45 } 46 pushdown(rt,r-l+1); 47 int m=(l+r)>>1; 48 if(L<=m) update(L,R,v,lson); 49 if(R>m) update(L,R,v,rson); 50 pushup(rt); 51 } 52 53 int query(int x,int l,int r,int rt){ 54 if(l==r){ 55 return l; 56 } 57 pushdown(rt,r-l+1); 58 int m=(l+r)>>1; 59 int ans; 60 if(x<maxmex[rt<<1]) ans=query(x,lson); 61 else ans=query(x,rson); 62 return ans; 63 } 64 65 66 int main(){ 67 int n; 68 while(scanf("%d",&n)!=EOF&&n){ 69 CLR(vis,0); 70 int k=0; //没有这个优化就会超时,,差距甚大!!! 71 for(int i=1;i<=n;i++){ 72 scanf("%d",&a[i]); 73 if(a[i]>n) a[i]=n+1; //!!! 74 vis[a[i]]=1; 75 for(int j=k;j<=n;j++)if(!vis[j]){ 76 mex[i]=j; 77 k=j; 78 break; 79 } 80 } 81 CLR(head,0x3f); //!!! 82 for(int i=n;i>0;i--){ 83 nex[i]=head[a[i]]; 84 head[a[i]]=i; 85 } 86 // for(int i=1;i<=n;i++) printf("%d ",nex[i]);puts(""); 87 build(1,n,1); 88 ll ans=sum[1]; 89 for(int i=1;i<=n;i++){ 90 update(i,i,0,1,n,1); 91 if(maxmex[1]>a[i]){ 92 int pos=query(a[i],1,n,1); 93 if(pos<nex[i]) update(pos,nex[i]-1,a[i],1,n,1); 94 } 95 ans+=sum[1]; 96 } 97 printf("%lld\n",ans); 98 } 99 return 0; 100 }
下面是递推的方法:
感谢大神的分析http://blog.csdn.net/cc_again/article/details/11856847,可是对我等蒟蒻还是难以理解~~
又想了很久才有点明白orz▄█?█●,模拟一波~
我们求以i结尾的区间的mex和f[i],再累加就是答案。为什么要这样求呢?因为可以从f[i-1]推出f[i]。记f[i] = f[i-1] +temp 。
初始temp=0,下面求temp。
现在看求第f[i]的过程,现在数列中加入了新的数val=a[i]
首先可以想到的是,对于上一个出现val的位置p(即代码里的last[a[i]]),p之前的元素(即f[p-1],f[p-2],…… f[1])对f[i]的结果没有任何影响(因为在a[i]之前的区间里已经有val了)
现在考虑p到i之间的元素,可以知道,他们中可能有 因为没有val这个元素 而导致 到f[i-1]的mex是比小于等于val的,而现在加入了val,那么它的mex就会增大
那么加入val之后都对哪些元素有影响呢?
这里定义一个覆盖:如果j到i之间出现了0到x的所有值,那么称x的 覆盖为j。
对于新加入的val,求出val-1的覆盖j,那么对于p(上一个val的位置)+1到 min ( j , last[val] )之间的元素,他们到a[i]=val这个位置的mex都【至少】是val+1,本来他们到a[i-1]的mex是val,即都变大了1。此时temp += min ( j , last[val] ) - p。
上面为什么说至少呢?说明还有可能大于val+1。
举个例子,6 4 0 2 1 5 3 。现在处理第七位3, f[7] = f[6] + temp = 9 + temp 。
加入3之后,根据上面的步骤,我们先求得2的覆盖是3,所以1到3的mex至少变成了4(他们在加入3之前是3),但是显然6的mex是7,4的mex是6,这一步只使得a[3]的mex值正确了。 此时temp +=3。 temp=3。
这是因为前面已经出现了4,5,6,在加入3之前他们没有任何,加入3导致他们也有了作用。这就等价于我们又加入了4,5,6
所以我们还要求3的覆盖,4的覆盖,5的覆盖……,直到mex值完全正确。
步骤如下:
加入4,last[4]=2,3的覆盖是3,所以1到2 ( min(3,last[4] ) 的mex变成了5, 此时temp +=2。 temp=5。
继续,加入5,last[5]=6,4的覆盖是2,所以1到2 ( min(2,last[5] ) 的mex变成了6, 此时temp +=2。 temp=7。
继续,加入6,last[6]=1,5的覆盖是1,所以1到1 ( min(2,last[5] ) 的mex变成了7, 此时temp +=1。 temp=8。
到此就结束了。
所以f[i] = f[i-1] +temp = 9 + 8 =17 。
下面代码里的mp就是上面分析中说到的覆盖。
1 #include <bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 int a[200002],last[200002],mp[200002]; 5 int main() 6 { 7 int n,cnt; 8 while(scanf("%d",&n)&&n){ 9 for(int i=1;i<=n;i++) 10 scanf("%d",&a[i]); 11 12 ll ans=0; 13 ll temp=0,pre; 14 memset(last,0,sizeof(last)); 15 memset(mp,0,sizeof(mp)); 16 for(int i=1;i<=n;i++){ 17 if(a[i]<=n){ //如果大于n,不会影响任何元素,mex的和不变 18 pre=last[a[i]]; 19 last[a[i]]=i; 20 for(int j=a[i];j<=n;j++){ 21 if(j) mp[j]=min(mp[j-1],last[j]); 22 else mp[j]=last[j]; 23 if(mp[j]>pre){ 24 temp+=mp[j]-pre; 25 }else break; 26 } 27 } 28 ans+=temp; 29 } 30 printf("%I64d\n",ans); 31 } 32 return 0; 33 }