ZOJ 2112
题意
给你n个数,有q次操作,每次操作可以修改某一个数,或是求区间第k小值。(多组数据)
样例输入
2 5 3 3 2 1 4 7 Q 1 4 3 C 2 6 Q 2 5 3 5 3 3 2 1 4 7 Q 1 4 3 C 2 6 Q 2 5 3
样例输出
3 6 3 6
SOL
如果不考虑单点修改就是主席树裸题。主席树本质上使用前缀和维护的,查询复杂度为O(1),但修改复杂度为O(n)。如果不用前缀和,查询复杂度为O(n),修改复杂度为O(n)。(这不是废话么…) 考虑在外面套上一层树状数组,使得查询和修改复杂度均为O(logn)。
简单的实现方式
普通的树状数组的每一个点存的都是某几个数的和(和lowbit有关),那么本道题的每一个点存的都是一棵线段树,并且相邻线段树之间用主席树的方式连接。
时间复杂度分析
修改:每次最多对log(n)棵线段树修改,每个节点中存的线段树要修改log(n)次,最多加入(n+m)个数(最开始要放入n个数),所以时间复杂度为O((n+m)lognlogn)。 查询:每次需要累加最多log(n)棵线段树,每棵线段树累加log(n)个节点,所以时间复杂度为O(mlognlogn)。
空间复杂度分析
每次操作修改log(n)棵线段树,每棵线段树修改log(n)个节点,所以空间复杂度为O((n+m)lognlogn)。
从这里可以可以看出时间复杂度和空间复杂度都是nlognlogn级别的,但是在ZOJ上内存限制只有64M,也就是说这种方式会MLE,如何解决?
静态建树+动态修改
分析数据范围可以看出这道题目的m比较小,所以可以选择一种神奇的方法:把查询的区间分为两步,第一步求出原区间,第二步加上修改的增量。第一步显然就是裸的主席树,第二步也就是没有初值的修改,不同之处在于只要修改m就可以了。空间复杂度级数没有变,但是常数大概能小4~6倍,空间刚好可以卡过去。下面就用这种方法具体解释如何实现。
建树
为了自己能够更好的掌握,这里的建树和更新方法都采用裸二叉树的方法。(还有许多方法,比如不需要用递归实现的,以后有空可以学习一下,基本上递归方式能够看懂非递归方式就很简单了)
void build(int l,int r,int &rt){ rt = ++tot; sum[rt] = 0; if (l == r) return; int m = (l + r) >> 1; build(l,m,ls[rt]); build(m+1,r,rs[rt]);}更新
void update(int l,int r,int &rt,int last,int p,int delta){ rt = ++tot; ls[rt] = ls[last]; rs[rt] = rs[last]; sum[rt] = sum[last] + delta; if (l == r) return; int m = (l + r) >> 1; if (p <= m) update(l,m,ls[rt],ls[last],p,delta); else update(m+1,r,rs[rt],rs[last],p,delta);}查询
个人感觉最难的部分就是查询。从这里可以看出非递归方式写起来真的非常清爽。。。特别是当树套树一层一层搞不清楚的时候写成递归真的要死。。。 这里的S数组存的是树状数组里面每个节点代表的线段树的根的编号(可以对比:root数组存的是原始数组里面每个节点代表的线段树的根的编号),怎么更新S后面会讲到。use1/2数组存的就是L-1和R的lowbit路径。 int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]];
这句话就是我刚才提到的“把查询的区间分为两步,第一步求出原区间,第二步加上修改的增量”。显然value过程求的是增量,sum数组显然是原区间的情况。
int value1(int x){ int re = 0; while (x > 0) {re += sum[ls[use1[x]]]; x -= lowbit(x);} return re;}int value2(int x){ int re = 0; while (x > 0) {re += sum[ls[use2[x]]]; x -= lowbit(x);} return re;}int Query(int L,int R,int k){ int lrt = root[L-1]; int rrt = root[R]; int l = 1,r = mm; for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = S[i]; for (int i = R;i ; i -= lowbit(i)) use2[i] = S[i]; while (l < r) { int m = (l + r) >> 1; int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]]; if (k <= cnt) { r = m; for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = ls[use1[i]]; for (int i = R;i ; i -= lowbit(i)) use2[i] = ls[use2[i]]; lrt = ls[lrt]; rrt = ls[rrt]; } else { l = m + 1; k = k - cnt; for (int i = L - 1;i > 0; i -= lowbit(i)) use1[i] = rs[use1[i]]; for (int i = R;i > 0; i -= lowbit(i)) use2[i] = rs[use2[i]]; lrt = rs[lrt]; rrt = rs[rrt]; } } return l;}修改
S数组一开始全部连到root[0]上表示全部为空。接下来每次修改一个数就新开一棵线段树并且用S记录位置。
void change(int x,int p,int delta){ while (x<=n) { update(1,k,S[x],S[x],p,delta); x += lowbit(x); }}完整代码
#include<cmath>#include<cstdio>#include<vector>#include<cstring>#include<iomanip>#include<stdlib.h>#include<iostream>#include<algorithm>#define ll long long#define inf 1000000000#define mod 1000000007#define N 2500010#define M 60010using namespace std;struct data{int kind,l,r,k;}query[10010];char op[10];int T,n,mm,q,i,tot,k;int a[M],b[M],sum[N],ls[N],rs[N],root[M],S[M],use1[M],use2[M];void build(int l,int r,int &rt){ rt = ++tot; sum[rt] = 0; if (l == r) return; int m = (l + r) >> 1; build(l,m,ls[rt]); build(m+1,r,rs[rt]);}void update(int l,int r,int &rt,int last,int p,int delta){ rt = ++tot; ls[rt] = ls[last]; rs[rt] = rs[last]; sum[rt] = sum[last] + delta; if (l == r) return; int m = (l + r) >> 1; if (p <= m) update(l,m,ls[rt],ls[last],p,delta); else update(m+1,r,rs[rt],rs[last],p,delta);}int lowbit(int x){return x&(-x);}int value1(int x){ int re = 0; while (x > 0) {re += sum[ls[use1[x]]]; x -= lowbit(x);} return re;}int value2(int x){ int re = 0; while (x > 0) {re += sum[ls[use2[x]]]; x -= lowbit(x);} return re;}int Query(int L,int R,int k){ int lrt = root[L-1]; int rrt = root[R]; int l = 1,r = mm; for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = S[i]; for (int i = R;i ; i -= lowbit(i)) use2[i] = S[i]; while (l < r) { int m = (l + r) >> 1; int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]]; if (k <= cnt) { r = m; for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = ls[use1[i]]; for (int i = R;i ; i -= lowbit(i)) use2[i] = ls[use2[i]]; lrt = ls[lrt]; rrt = ls[rrt]; } else { l = m + 1; k = k - cnt; for (int i = L - 1;i > 0; i -= lowbit(i)) use1[i] = rs[use1[i]]; for (int i = R;i > 0; i -= lowbit(i)) use2[i] = rs[use2[i]]; lrt = rs[lrt]; rrt = rs[rrt]; } } return l;}void change(int x,int p,int delta){ while (x<=n) { update(1,k,S[x],S[x],p,delta); x += lowbit(x); }}int hash(int x){ return lower_bound(b+1,b+k+1,x)-b;}int main(){ cin>>T; while (T--) { cin>>n>>q; for (i = 1;i <= n; i++) scanf("%d",&a[i]); for (i = 1;i <= n; i++) b[i] = a[i]; k = n; for (i = 1;i <= q; i++) { scanf("%s",op); if (op[0] == 'Q') { query[i].kind = 0; scanf("%d%d%d",&query[i].l,&query[i].r,&query[i].k); } else { query[i].kind = 1; scanf("%d%d",&query[i].l,&query[i].r); b[++k] = query[i].r; } } sort(b+1,b+k+1); k = unique(b+1,b+k+1) - (b+1); tot = 0; build(1,k,root[0]); for (i = 1;i <= n; i++) update(1,k,root[i],root[i-1],hash(a[i]),1); for (i = 1;i <= n; i++) S[i] = root[0]; mm = k; for (i = 1;i <= q; i++) { if (query[i].kind == 0) PRintf("%d/n",b[Query(query[i].l,query[i].r,query[i].k)]); else { change(query[i].l,hash(a[query[i].l]),-1); change(query[i].l,hash(query[i].r),1); a[query[i].l] = query[i].r; } } } return 0;}