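// This is again a sweep-line problem. Note: to maintain the number of separate covered intervals, each node also keeps two flags, l and r, which record whether a segment is present starting from the node's left endpoint and whether a segment is present starting from its right endpoint (extending leftward), respectively.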
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 100005
using namespace std;
// One segment-tree node. A node covers the elementary intervals numbered
// left..right; its covered length is measured as hash[right+1]-hash[left].
struct tree
{
    int left;   // index of the first elementary interval in this node's range
    int right;  // index of the last elementary interval in this node's range
    int sum;    // covered length inside the range
    int cnt;    // number of separate covered runs inside the range
    int q;      // net count of edges applied to exactly this node's range
    int l;      // 1 if the leftmost elementary interval is covered, else 0
    int r;      // 1 if the rightmost elementary interval is covered, else 0
};
// One sweep event: an edge of an input rectangle spanning x1..x2 at height h.
struct seg
{
    int x1;    // first x coordinate of the edge (input value a)
    int x2;    // second x coordinate of the edge (input value c)
    int h;     // height at which the edge lies (input value b or d)
    int flag;  // +1 for the edge stored with h=b, -1 for the edge stored with h=d
};
// Shared state of the sweep. All arrays are used from index 1.
// NOTE(review): with `using namespace std` in effect, the identifier `hash`
// may be ambiguous with std::hash on C++11-and-later toolchains -- confirm
// against the target compiler.
tree node[maxn];   // segment-tree nodes; children of node i are 2i and 2i+1
seg s[maxn];       // sweep events, two per rectangle
int n;             // number of rectangles
int a,b,c,d;       // the four coordinates of the rectangle currently being read
int tot=0;         // perimeter accumulated so far
int hash[maxn];    // x coordinates of all events; sorted before lookups
int father[maxn];  // father[k] = node index of the leaf for elementary interval k
int rec=0;         // root covered length after the previous event
// Ordering for sweep events: ascending height; at equal height the event
// with the larger flag (+1) is placed before the one with the smaller (-1).
bool cmp(seg x,seg y)
{
    if (x.h!=y.h)
        return x.h<y.h;
    return x.flag>y.flag;
}
void build(int i,int left,int right)
{
node[i].left=left;
node[i].right=right;
node[i].sum=0;
node[i].cnt=0;
node[i].q=0;
node[i].l=0;
node[i].r=0;
if (left==right)
{
father[left]=i;
return;
}
int mid=(left+right)>>1;
i=i<<1;
build(i,left,mid);
build(i|1,mid+1,right);
}
// Binary search for x in hash[1..2*n], which must already be sorted ascending.
// Returns an index k with hash[k]==x, or -1 when x does not occur. For values
// that are present, the probe sequence and the returned index are the same as
// with the plain (l+r)>>1 midpoint.
int find(int x)
{
    int lo=1,hi=2*n;
    while (lo<=hi)   // bounded: stop once the window is empty instead of looping forever
    {
        const int mid=lo+((hi-lo)>>1);
        if (hash[mid]==x) return mid;
        if (hash[mid]<x) lo=mid+1;
        else hi=mid-1;
    }
    return -1;       // x is not among the stored coordinates
}
void pushup1(int i)
{
int left=node[i].left,right=node[i].right;
if (node[i].q>0) node[i].sum=hash[right+1]-hash[left];
else if (left==right) node[i].sum=0;
else node[i].sum=node[i<<1].sum+node[i<<1|1].sum;
}
void pushup2(int i)//danger!
{
int left=node[i].left,right=node[i].right;
if (node[i].q>0)
{
node[i].l=1;
node[i].r=1;
node[i].cnt=1;
}
else if (left==right)
{
node[i].l=0;
node[i].r=0;
node[i].cnt=0;
}
else
{
node[i].l=node[i<<1].l;
node[i].r=node[i<<1|1].r;
node[i].cnt=node[i<<1].cnt+node[i<<1|1].cnt;
if ((node[i<<1].r==1) && (node[i<<1|1].l==1))
node[i].cnt--;
}
}
// Refresh all cached aggregates of node i.
void pushup(int i)
{
    pushup1(i);   // covered length: sum
    pushup2(i);   // run count and edge flags: cnt, l, r
}
void modify(int i,int l,int r,int flag)
{
int left=node[i].left,right=node[i].right;
if ((l==left) && (r==right))
{
node[i].q=node[i].q+flag;
pushup(i);
return;
}
int mid=(left+right)>>1;
if (r<=mid) modify(i<<1,l,r,flag);
else if (l>=mid+1) modify(i<<1|1,l,r,flag);
else
{
modify(i<<1,l,mid,flag);
modify(i<<1|1,mid+1,r,flag);
}
pushup(i);
}
int main()
{
scanf("%d",&n);
build(1,1,2*n);
for (int i=1;i<=n;i++)
{
scanf("%d%d%d%d",&a,&b,&c,&d);
s[i*2-1].x1=a;s[i*2-1].x2=c;s[i*2-1].h=b;s[i*2-1].flag=1;
s[i*2].x1=a;s[i*2].x2=c;s[i*2].h=d;s[i*2].flag=-1;
hash[i*2-1]=a;hash[i*2]=c;
}
sort(s+1,s+2*n+1,cmp);
sort(hash+1,hash+2*n+1);
for (int i=1;i<=2*n-1;i++)
{
int l=find(s[i].x1),r=find(s[i].x2)-1;
modify(1,l,r,s[i].flag);
tot=tot+abs(node[1].sum-rec)+node[1].cnt*2*(s[i+1].h-s[i].h);
rec=node[1].sum;
}
tot=tot+(s[n*2].x2-s[n*2].x1);
printf("%d\n",tot);
return 0;
}