在没有修改操作时,应用划分树可以在O(MlogN)时间内解决查找区间第K小的问题,但是在引入修改(将原序列中的某个值改为另一个值)之后,划分树就不行了。
这时,需要数据结构联合的思想。
可以观察一下:
(1)区间操作:使用线段树;
(2)修改值(其实是先删除再插入)和找第K小:使用平衡树;
现在这两种操作都有,应该使用线段树+平衡树
准确来说是线段树套平衡树,即对原序列建立一棵线段树,其中的每个结点内套一棵对该结点管辖区间内的平衡树。

<1>结点类型(结构):
struct seg_node {
    
int l, r, mid, lch, rch, rt;
} T0[MAXN0];
struct SBT_node {
    
int v, l, r, p, sz0, sz, mul;
} T[MAXN];
其中seg_node是线段树结点类型,SBT_node是平衡树(SBT)结点类型。需要注意的是seg_node里面的rt域(root的缩写),它是该结点内套的平衡树的根结点下标索引(因为对于任意一棵平衡树,只要知道了其根结点就可以遍历整棵树)。

<2>建树:
建树是线段树和平衡树一起建。在建立线段树结点的时候,先建立一棵空的平衡树(rt域置0),然后再在平衡树里面逐个插入该结点管辖区间内的所有元素即可;

<3>修改:
修改操作要注意:如果要将A[x](A为原序列)的值修改为y,则需要自顶向下遍历整棵线段树,将所有包含了A[x]的结点内的平衡树全部执行“删除v=A[x](这个可以通过真正维护一个序列得到),再插入y”的操作;

<4>找区间第K小:
这个操作极其麻烦。需要借助二分。
设要在区间[l, r]中找到第K小。首先将[l, r]拆分成若干个线段树结点,然后二分一个值x,在这些结点的平衡树中找到x的rank(这里的rank指平衡树中有多少个值比x小,不需要加1),加起来,最后再加1,就是x在[l, r]中的总名次。问题是,设[l..r]中第K小的数为v1,第(K+1)小的数为v2(如果不存在的话,v2=+∞),则[v1, v2)内的数都是“第K小”的。因此,不能二分数字,而应该二分元素。设S[i]为原序列中第i小的数,二分i,然后在根结点的平衡树中找到第i小的即为S[i],再求其名次,这样直到找到总名次为K的元素为止。问题还没完,序列中可能有元素的值相同,这时可能永远也找不到第K小的(比如序列1 2 3 3 3 4 5,K=4,若“序列中比x小的元素总数+1”为x的名次,则永远也找不到第4小的),因此,若这样求出的“名次”小于等于K,都应该将下一次的左边界设为mid而不是(mid+1),而“名次”大于K时,该元素肯定不是第K小的,所以下一次右边界设为(mid-1)。

代码(本机测最猥琐数据4s以内,交到ZJU上TLE,不知为什么,神犇指点一下,3x):
#include <iostream>
#include 
<stdio.h>
using namespace std;
#define re(i, n) for (int i=0; i<n; i++)
#define re3(i, l, r) for (int i=l; i<=r; i++)
const int MAXN0 = 110000, MAXN = 930000, INF = ~0U >> 2;
struct seg_node {
    
int l, r, mid, lch, rch, rt;
} T0[MAXN0];
struct SBT_node {
    
int v, l, r, p, sz0, sz, mul;
} T[MAXN];
int No0, No, n, root, rt0, a[MAXN0 >> 1], b[MAXN0 >> 1], l1, r1, len;
void slc(int _p, int _c)
{
    T[_p].l 
= _c; T[_c].p = _p;
}
void src(int _p, int _c)
{
    T[_p].r 
= _c; T[_c].p = _p;
}
void upd(int x)
{
    T[x].sz0 
= T[T[x].l].sz0 + T[T[x].r].sz0 + T[x].mul;
    T[x].sz 
= T[T[x].l].sz + T[T[x].r].sz + 1;
}
void lrot(int x)
{
    
int y = T[x].p; if (y == rt0) T[rt0 = x].p = 0else {int p = T[y].p; if (y == T[p].l) slc(p, x); else src(p, x);}
    src(y, T[x].l); slc(x, y); T[x].sz0 
= T[y].sz0; T[x].sz = T[y].sz; upd(y);
}
void rrot(int x)
{
    
int y = T[x].p; if (y == rt0) T[rt0 = x].p = 0else {int p = T[y].p; if (y == T[p].l) slc(p, x); else src(p, x);}
    slc(y, T[x].r); src(x, y); T[x].sz0 
= T[y].sz0; T[x].sz = T[y].sz; upd(y);
}
void maintain(int x, bool ff)
{
    
int z;
    
if (ff) {
        
if (T[T[T[x].r].r].sz > T[T[x].l].sz) {z = T[x].r; lrot(z);}
        
else if (T[T[T[x].r].l].sz > T[T[x].l].sz) {z = T[T[x].r].l; rrot(z); lrot(z);} else return;
    } 
else {
        
if (T[T[T[x].l].l].sz > T[T[x].r].sz) {z = T[x].l; rrot(z);}
        
else if (T[T[T[x].l].r].sz > T[T[x].r].sz) {z = T[T[x].l].r; lrot(z); rrot(z);} else return;
    }
    maintain(T[z].l, 
0); maintain(T[z].r, 1); maintain(z, 0); maintain(z, 1);
}
int find(int _v)
{
    
int i = rt0, v0;
    
while (i) {
        v0 
= T[i].v;
        
if (_v == v0) return i; else if (_v < v0) i = T[i].l; else i = T[i].r;
    }
    
return 0;
}
void ins(int _v)
{
    
if (!rt0) {
        T[
++No].v = _v; T[No].l = T[No].r = T[No].p = 0; T[No].sz0 = T[No].sz = T[No].mul = 1; rt0 = No;
    } 
else {
        
int i = rt0, j, v0;
        
while (1) {
            T[i].sz0
++; v0 = T[i].v;
            
if (_v == v0) {T[i].mul++return;} else if (_v < v0) j = T[i].l; else j = T[i].r;
            
if (j) i = j; else break;
        }
        T[
++No].v = _v; T[No].l = T[No].r = 0; T[No].sz0 = T[No].sz = T[No].mul = 1if (_v < v0) slc(i, No); else src(i, No);
        
while (i) {T[i].sz++; maintain(i, _v > T[i].v); i = T[i].p;}
    }
}
void del(int x)
{
    
if (T[x].mul > 1) {
        T[x].mul
--;
        
while (x) {T[x].sz0--; x = T[x].p;}
    } 
else {
        
int l = T[x].l, r = T[x].r;
        
if (!|| !r) {
            
if (x == rt0) T[rt0 = l + r].p = 0else {
                
int p = T[x].p; if (x == T[p].l) slc(p, l + r); else src(p, l + r);
                
while (p) {T[p].sz0--; T[p].sz--; p = T[p].p;}
            }
        } 
else {
            
int i = l, j;
            
while (j = T[i].r) i = j;
            T[x].v 
= T[i].v; T[x].mul = T[i].mul; int p = T[i].p; if (i == T[p].l) slc(p, T[i].l); else src(p, T[i].l);
            
while (p) {upd(p); p = T[p].p;}
        }
    }
}
int Find_Kth(int K)
{
    
int i = rt0, s0, m0;
    
while (i) {
        s0 
= T[T[i].l].sz0; m0 = T[i].mul;
        
if (K <= s0) i = T[i].l; else if (K <= s0 + m0) return T[i].v; else {K -= s0 + m0; i = T[i].r;}
    }
}
int rank(int _v)
{
    
int i = rt0, tot = 0, v0;
    
while (i) {
        v0 
= T[i].v;
        
if (_v == v0) {tot += T[T[i].l].sz0; return tot;} else if (_v < v0) i = T[i].l; else {tot += T[T[i].l].sz0 + T[i].mul; i = T[i].r;}
    }
    
return tot;
}
int mkt(int l, int r)
{
    T0[
++No0].l = l; T0[No0].r = r; int mid = l + r >> 1; T0[No0].mid = mid; rt0 = 0;
    re3(i, l, r) ins(a[i]); T0[No0].rt 
= rt0;
    
if (l < r) {int No00 = No0; T0[No00].lch = mkt(l, mid); T0[No00].rch = mkt(mid + 1, r); return No00;} else {T0[No0].lch = T0[No0].rch = 0return No0;}
}
void fs(int x)
{
    
if (x) {
        
int l0 = T0[x].l, r0 = T0[x].r;
        
if (l0 >= l1 && r0 <= r1) b[len++= T0[x].rt; else if (l0 > r1 || r0 < l1) returnelse {fs(T0[x].lch); fs(T0[x].rch);}
    }
}
void C(int x, int _v)
{
    
int i = root, l0, r0, mid0, v0 = a[x], N;
    
while (i) {
        l0 
= T0[i].l; r0 = T0[i].r; mid0 = T0[i].mid; rt0 = T0[i].rt;
        N 
= find(v0); del(N); ins(_v); T0[i].rt = rt0;
        
if (x <= mid0) i = T0[i].lch; else i = T0[i].rch;
    }
    a[x] 
= _v;
}
int Q(int K)
{
    len 
= 0; fs(root);
    
int ls = 1, rs = n, mids, midv, tot;
    
while (ls < rs) {
        mids 
= ls + rs + 1 >> 1; rt0 = T0[root].rt; midv = Find_Kth(mids);
        tot 
= 1; re(i, len) {rt0 = b[i]; tot += rank(midv);}
        
if (tot <= K) ls = mids; else rs = mids - 1;
    }
    rt0 
= T0[root].rt; return Find_Kth(ls);
}
int main()
{
    
int tests, m, x, y, K;
    
char ch;
    scanf(
"%d"&tests);
    re(testno, tests) {
        scanf(
"%d%d"&n, &m); No0 = No = 0;
        re(i, n) scanf(
"%d"&a[i]); ch = getchar();
        root 
= mkt(0, n - 1);
        re(i, m) {
            ch 
= getchar();
            
if (ch == 'C') {
                scanf(
"%d%d%*c"&x, &y);
                C(
--x, y);
            } 
else {
                scanf(
"%d%d%d%*c"&l1, &r1, &K);
                l1
--; r1--; printf("%d\n", Q(K));
            }
        }
    }
    
return 0;
}

只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   知识库   博问   管理