Splay 真心比 SBT（Size Balanced Tree）难调多了 QAQ
光是调 remove 就花了一上午，终于搞好了 QAQ
// Adoption-centre simulation backed by a splay tree.
//
// Input (pet.in): n, then n records "flag key". flag == 1 is an adopter
// ("human"); any other flag is a pet. Only one group is ever waiting, and the
// tree stores the keys of whichever group is currently waiting. An arrival
// from the other group is matched with the closest stored key (the smaller key
// wins a tie); that key is removed and the distance is added to the answer
// modulo P.
// Output (pet.out): the accumulated answer.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>

const int P = 1000000; // the answer is reported modulo P

struct splay
{
    splay *ch[2], *fa; // ch[0]: left child, ch[1]: right child, fa: parent
    int data;          // key stored in this node
    int cnt;           // multiplicity of the key (insert always uses 1; equal keys get separate nodes)
    int size;          // sum of cnt over the subtree rooted here
} *root;               // root of the tree; NULL when the tree is empty

// Size of the subtree rooted at x; an empty subtree has size 0.
static int sizeOf(const splay *x)
{
    return x != NULL ? x->size : 0;
}

// Recompute x->size from its own count and its children's sizes.
void calc(splay *x)
{
    x->size = x->cnt + sizeOf(x->ch[0]) + sizeOf(x->ch[1]);
}

// Rotate x above its parent y. flag == 1 is a right rotation (x is y's left
// child); flag == 0 is a left rotation (x is y's right child). One routine
// handles both directions to keep the code short.
void rot(splay *x, bool flag)
{
    splay *y = x->fa;
    y->ch[!flag] = x->ch[flag]; // x's inner subtree moves under y
    if (x->ch[flag] != NULL)
        x->ch[flag]->fa = y;
    x->fa = y->fa;              // x takes y's place under y's old parent
    if (y->fa != NULL)
    {
        if (y->fa->ch[0] == y)
            y->fa->ch[0] = x;
        else
            y->fa->ch[1] = x;
    }
    x->ch[flag] = y;
    y->fa = x;
    if (y == root)
        root = x;
    // Subtree sizes changed only for y and x; y is now below x, so refresh it first.
    calc(y);
    calc(x);
}

// Splay x upwards until its parent is f; f == NULL makes x the root.
void Splay(splay *x, splay *f)
{
    if (x == f || x == NULL)
        return;
    while (x->fa != f)
    {
        splay *y = x->fa;
        bool xIsLeft = (y->ch[0] == x);
        if (y->fa == f)
        {
            rot(x, xIsLeft); // zig: one rotation reaches the target
            continue;
        }
        splay *z = y->fa;
        bool yIsLeft = (z->ch[0] == y);
        if (xIsLeft == yIsLeft)
        {
            rot(y, yIsLeft); // zig-zig: lift the parent first, then x
            rot(x, xIsLeft);
        }
        else
        {
            rot(x, xIsLeft); // zig-zag: lift x twice
            rot(x, yIsLeft);
        }
    }
    if (f == NULL)
        root = x;
    calc(x);
    if (f != NULL)
        calc(f);
}

// Insert a new node holding `data` and splay it to the root.
// Equal keys descend to the right, so duplicates become separate nodes.
void insert(int data)
{
    splay *parent = NULL;
    splay *cur = root;
    while (cur != NULL)
    {
        parent = cur;
        cur = cur->ch[data >= cur->data];
    }
    splay *node = new splay;
    node->ch[0] = node->ch[1] = NULL;
    node->fa = parent;
    node->data = data;
    node->cnt = 1;
    node->size = 1;
    if (parent == NULL)
        root = node;
    else
        parent->ch[data >= parent->data] = node;
    Splay(node, NULL);
}

// Return a node whose key equals `data`, or NULL if there is none.
splay *find(int data)
{
    splay *x = root;
    while (x != NULL && x->data != data)
        x = x->ch[data > x->data];
    return x;
}

// Node holding the largest key, or NULL if the tree is empty.
splay *getmax()
{
    splay *x = root;
    while (x != NULL && x->ch[1] != NULL)
        x = x->ch[1];
    return x;
}

// Node holding the smallest key, or NULL if the tree is empty.
splay *getmin()
{
    splay *x = root;
    while (x != NULL && x->ch[0] != NULL)
        x = x->ch[0];
    return x;
}

// 1-based in-order position of a node holding `data`, or 0 if `data` is not
// stored. Splays that node to the root so its left subtree holds exactly the
// entries that precede it.
int rank(int data)
{
    splay *x = find(data);
    if (x == NULL)
        return 0;
    Splay(x, NULL);
    return sizeOf(x->ch[0]) + 1;
}

// Node at 1-based in-order position k, or NULL if k is out of range.
splay *select(int k)
{
    splay *x = root;
    while (x != NULL)
    {
        int leftSize = sizeOf(x->ch[0]);
        if (k <= leftSize)
            x = x->ch[0];
        else if (k > leftSize + x->cnt)
        {
            k -= leftSize + x->cnt;
            x = x->ch[1];
        }
        else
            return x;
    }
    return NULL;
}

// Node with the largest key strictly smaller than `data`, searching from x;
// y is the best candidate found so far (pass NULL initially). Equal keys are
// skipped to the left so the result is the true predecessor.
splay *pred(splay *x, splay *y, int data)
{
    while (x != NULL)
    {
        if (data > x->data)
        {
            y = x;
            x = x->ch[1];
        }
        else
            x = x->ch[0];
    }
    return y;
}

// Node with the smallest key >= `data`, searching from x; y is the best
// candidate found so far (pass NULL initially). An exact match is returned
// immediately.
splay *succ(splay *x, splay *y, int data)
{
    while (x != NULL)
    {
        if (x->data == data)
            return x;
        if (x->data > data)
        {
            y = x;
            x = x->ch[0];
        }
        else
            x = x->ch[1];
    }
    return y;
}

// Remove one node holding `data` and free it; no-op if `data` is absent.
void remove(int data)
{
    splay *x = find(data);
    if (x == NULL)
        return;
    Splay(x, NULL); // x is now the root
    splay *left = x->ch[0];
    splay *right = x->ch[1];
    if (left == NULL || right == NULL)
    {
        // At most one subtree: it becomes the whole tree.
        root = (left != NULL) ? left : right;
        if (root != NULL)
            root->fa = NULL;
    }
    else
    {
        // Splay x's in-order successor (leftmost node of the right subtree) to
        // the root. x then becomes its left child with no right child, so x is
        // bypassed by hanging x's left subtree directly under the successor.
        splay *y = right;
        while (y->ch[0] != NULL)
            y = y->ch[0];
        Splay(y, NULL);
        y->ch[0] = left;
        left->fa = y;
        calc(y);
    }
    delete x;
}

// Match an arrival with key `key` against the group stored in the tree.
// Picks the stored key closest to `key`, preferring the smaller one on a tie.
// It removes that key and reports the distance in `cost`. Returns false and
// leaves the tree untouched when the tree is empty.
static bool adoptClosest(int key, long long &cost)
{
    splay *lo = pred(root, NULL, key); // largest stored key < key
    splay *hi = succ(root, NULL, key); // smallest stored key >= key
    if (lo == NULL && hi == NULL)
        return false;
    bool takeLo;
    if (lo == NULL)
        takeLo = false;
    else if (hi == NULL)
        takeLo = true;
    else // lo < key <= hi, so both distances are non-negative; ties go to lo
        takeLo = static_cast<long long>(key) - lo->data <= static_cast<long long>(hi->data) - key;
    if (takeLo)
    {
        cost = static_cast<long long>(key) - lo->data;
        remove(lo->data);
    }
    else
    {
        cost = static_cast<long long>(hi->data) - key;
        remove(hi->data);
    }
    return true;
}

int main()
{
    std::freopen("pet.in", "r", stdin);
    std::freopen("pet.out", "w", stdout);
    int n = 0;
    if (std::scanf("%d", &n) != 1)
        n = 0;
    int ans = 0;    // accumulated distance, kept modulo P
    int pets = 0;   // pets currently waiting (stored in the tree)
    int humans = 0; // adopters currently waiting (stored in the tree)
    for (int i = 1; i <= n; i++)
    {
        int flag, key;
        if (std::scanf("%d%d", &flag, &key) != 2)
            break; // truncated input: stop instead of replaying the last record
        int &waiting = (flag == 1) ? pets : humans;  // the opposite group
        int &arriving = (flag == 1) ? humans : pets; // the arrival's own group
        long long cost = 0;
        if (waiting == 0)
        {
            insert(key);
            ++arriving;
        }
        else if (adoptClosest(key, cost))
        {
            ans = static_cast<int>((ans + cost) % P);
            --waiting;
        }
    }
    std::cout << ans;
    return 0;
}
时间: 2024-10-12 21:47:10