首先要知道每次拿走最小才会达到最优,因为最小的不会给其他的提供任何加分,只有可能减小加分。
删除卡片的次序确定了,剩下的就是确定每段区间的左右端点。
pos[i] 表示数字 i 在初始序列中的位置。
首先枚举i (i = 1 -> n),如果不需删除,则将pos[i]放入set<int> S中,如果不需删除,则在S中二分查找上下界。
总的时间复杂度为o( (n-k)*log(k) )。
#include <algorithm> #include <iostream> #include <cstring> #include <cstdlib> #include <cstdio> #include <queue> #include <cmath> #include <stack> #include <map> #include <set> #include <ctime> #include <iomanip> #pragma comment(linker,"/STACK:1024000000"); #define EPS (1e-6) #define LL long long #define ULL unsigned long long #define INF 0x3f3f3f3f #define Mod 1000000007 #define mod 1000000007 /** I/O Accelerator Interface .. **/ #define g (c=getchar()) #define d isdigit(g) #define p x=x*10+c-'0' #define n x=x*10+'0'-c #define pp l/=10,p #define nn l/=10,n template<class T> inline T& RD(T &x) { char c; while(!d); x=c-'0'; while(d)p; return x; } template<class T> inline T& RDD(T &x) { char c; while(g,c!='-'&&!isdigit(c)); if (c=='-') { x='0'-g; while(d)n; } else { x=c-'0'; while(d)p; } return x; } inline double& RF(double &x) //scanf("%lf", &x); { char c; while(g,c!='-'&&c!='.'&&!isdigit(c)); if(c=='-')if(g=='.') { x=0; double l=1; while(d)nn; x*=l; } else { x='0'-c; while(d)n; if(c=='.') { double l=1; while(d)nn; x*=l; } } else if(c=='.') { x=0; double l=1; while(d)pp; x*=l; } else { x=c-'0'; while(d)p; if(c=='.') { double l=1; while(d)pp; x*=l; } } return x; } #undef nn #undef pp #undef n #undef p #undef d #undef g using namespace std; int num[1000010]; int pos[1000010]; bool ap[1000010]; int st[4001000]; set<int> s; int Init(int site,int l,int r) { if(l == r) return st[site] = 1; int mid = (l+r)>>1; return st[site] = Init(site<<1,l,mid) + Init(site<<1|1,mid+1,r); } int Query(int site,int L,int R,int l,int r) { if(L == l && R == r) return st[site]; int mid = (L+R)>>1; if(r <= mid) return Query(site<<1,L,mid,l,r); if(mid < l) return Query(site<<1|1,mid+1,R,l,r); return Query(site<<1,L,mid,l,mid) + Query(site<<1|1,mid+1,R,mid+1,r); } void Update(int site,int l,int r,int x) { if(l == r) { st[site] = 0; return ; } int mid = (l+r)>>1; if(x <= mid) Update(site<<1,l,mid,x); else Update(site<<1|1,mid+1,r,x); st[site] = st[site<<1] + st[site<<1|1]; } int main() { int n,k,i,j,x; scanf("%d %d",&n,&k); for(i = 1;i <= n; ++i) scanf("%d",&num[i]),pos[num[i]] = i; memset(ap,false,sizeof(ap)); for(i = 1;i <= k; ++i) scanf("%d",&x),ap[x] = true; set<int>::iterator it; LL sum = 0; Init(1,1,n); s.insert(n+1); s.insert(0); for(i = 1;i <= n; ++i) { if(ap[i]) { s.insert(pos[i]); continue; } it = s.upper_bound(pos[i]); int r = *it-1; int l = *(--it)+1; sum += Query(1,1,n,l,r); Update(1,1,n,pos[i]); } cout<<sum<<endl; return 0; }
时间: 2024-11-10 04:08:45