POJ 2104
题意
给定1到n的排列,每次询问某一区间内的第k小值。
样例输入
7 3 1 5 2 6 3 7 4 2 5 3 4 4 1 1 7 3
样例输出
5 6 3
主席树介绍
可持久化线段树,函数式线段树。 有点抽象,能够理解但还不是很熟练,代码不长,但是非常简练,有很多技巧,目前当做黑箱。
可持久化:每次操作尽量用新节点表示而不是修改原节点,这样就能保留所有历史信息。 函数式:函数式编程里变量常常是不变的,线段树的函数式写法就是这样。
我们用区间k小值来解释。
一些预处理
离散化(排序+去重),其实本道题不需要这个操作。 下面的代码用到了STL的很多技巧,用unique()
函数去重,用lower_bound()
重新映射a数组。
for (i = 1;i <= n; i++) scanf("%d",&a[i]);for (i = 1;i <= n; i++) b[i] = a[i];sort(b+1,b+n+1);k = unique(b+1,b+n+1) - (b+1);for (i = 1;i <= n; i++) a[i] = lower_bound(b+1,b+k+1,a[i])-b;假设求整个区间的k小值
这个问题可以用AVL树做,但是这里介绍一种类似平衡树的方法。假设某个节点的区间为[l,r],则这个节点记录的是在a数组中有多少个a[i]满足l<=a[i]<=r。这样搜索第k小值时,如果左孩子数量小于k则k小值在左子树中,反之则在右子树中。复杂度log(n)。
对于任意区间[L,R]
建立n棵线段树,每棵维护[1,i]的数字出现情况。 显然这n棵线段树每个节点代表的区间都是一样的,所以这n棵线段树同构。 用第R棵线段树去“减”第(L-1)棵线段树,得出来的结果就是区间[L,R]的情况,对这棵树套用一遍上面求整个区间的方法就可以求出[L,R]中的k小值。
如何节约空间
上面的方法看起来还是比较具体的,但是会MLE(n棵线段树)。 下面的优化就是主席树的精髓:如何扔掉重复的节点。有点抽象,这段话看懂了就比较轻松了。
我们发现,第i棵线段树和第i+1棵线段树的区别在于加入了a[i+1]这个数,而a[i+1]在第i棵树上从根出发向下走,走过的节点+1就变成了第i+1棵线段树。(你可以自己画一下看看有什么不同)
也就是说相邻两棵线段树之间不同节点个数至多为log(n)个,换句话说剩下这么多的节点都是一样的! 那么重复的节点就可以扔掉了。比如说一个节点的左孩子是重复的,那么我不需要多开一个节点,而是直接连到前一棵树上。 看起来比较复杂,但是编程中有很多技巧,最后代码比普通线段树还短。 P.S. 怕以后忘记这里写的会很详细。
sol
预处理这里就不再写了。
建树
现在连建树都要重新写了TAT。 其实只要建一棵空树即可,后面的树都是连到这棵树上。 但是后面再update和query的时候有一个问题:左孩子和右孩子并不能简单的乘2和乘2加1,如何解决?
//root[i]表示第i棵树的根的位置void build(int l,int r,int &rt){ rt = ++tot; sum[rt] = 0; if (l == r) return; int m = (l + r) >> 1; build(l,m,ls[rt]); build(m+1,r,rs[rt]);}...tot = 0;build(1,k,root[0]);用最朴素的方法:一个一个累加! 这里有一个技巧就是用了&,也就是说等到搜到这个点的时候自然会把这个点的位置给传回来。这个技巧剩下了不少代码,在后面的update和query中可以自己体会。
更新
//ls表示左孩子位置 rs表示右孩子位置 last表示前一棵树、当前节点的位置void update(int l,int r,int &rt,int last,int p){ rt = ++tot; ls[rt] = ls[last]; rs[rt] = rs[last];//暂时两个孩子都连到前一棵树的对应孩子上 sum[rt] = sum[last] + 1;//这一步可以解释是哪log(n)个点的值发生了修改! if (l == r) return; int m = (l + r) >> 1; if (p <= m) update(l,m,ls[rt],ls[last],p); else update(m+1,r,rs[rt],rs[last],p);//修改的那个节点开辟出一个新节点 ls/rs会回传新的节点的位置!前面讲到过}...for (i = 1;i <= n; i++) update(1,k,root[i],root[i-1],a[i]);这样一来就把这“n棵线段树”都建好了。可以看出虽然节点总数为nlog(n),但是却把所有的情况都记录下来了,这就是“可持久化”。
查询
int query(int ss,int tt,int l,int r,int k){ if (l == r) return l; int m = (l + r) >> 1; int cnt = sum[ls[tt]] - sum[ls[ss]];//用第tt棵线段树减去第ss棵线段树 if (k <= cnt) return query(ls[ss],ls[tt],l,m,k); else return query(rs[ss],rs[tt],m+1,r,k-cnt);}...while (q--) { scanf("%d%d%d",&ql,&qr,&qk); int res = query(root[ql-1],root[qr],1,k,qk); PRintf("%d/n",b[res]); }有了前面的铺垫,查询就比较简单了。
完整代码
#include<cmath>#include<cstdio>#include<vector>#include<cstring>#include<iomanip>#include<stdlib.h>#include<iostream>#include<algorithm>#define ll long long#define inf 1000000000#define mod 1000000007#define N 100000using namespace std;int a[N],b[N],root[N*20],ls[N*20],rs[N*20],sum[N*20];int n,q,i,tot,k,ql,qr,qk;void build(int l,int r,int &rt){ rt = ++tot; sum[rt] = 0; if (l == r) return; int m = (l + r) >> 1; build(l,m,ls[rt]); build(m+1,r,rs[rt]);}void update(int l,int r,int &rt,int last,int p){ rt = ++tot; ls[rt] = ls[last]; rs[rt] = rs[last]; sum[rt] = sum[last] + 1; if (l == r) return; int m = (l + r) >> 1; if (p <= m) update(l,m,ls[rt],ls[last],p); else update(m+1,r,rs[rt],rs[last],p);}int query(int ss,int tt,int l,int r,int k){ if (l == r) return l; int m = (l + r) >> 1; int cnt = sum[ls[tt]] - sum[ls[ss]]; if (k <= cnt) return query(ls[ss],ls[tt],l,m,k); else return query(rs[ss],rs[tt],m+1,r,k-cnt);}int main(){ cin>>n>>q; for (i = 1;i <= n; i++) scanf("%d",&a[i]); for (i = 1;i <= n; i++) b[i] = a[i]; sort(b+1,b+n+1); k = unique(b+1,b+n+1) - (b+1); for (i = 1;i <= n; i++) a[i] = lower_bound(b+1,b+k+1,a[i])-b; tot = 0; build(1,k,root[0]); for (i = 1;i <= n; i++) update(1,k,root[i],root[i-1],a[i]); while (q--) { scanf("%d%d%d",&ql,&qr,&qk); int res = query(root[ql-1],root[qr],1,k,qk); printf("%d/n",b[res]); } return 0;}