线段树(Segment Tree)是一种在数组之上维护的树结构(也叫statistic tree),它能够存储一个数组a
,支持下面两个操作:
- 以$O(lgn)$的时间查询数组上的一个区间
a[l:r]
内的某种统计值,比如查询区间内的元素之和,或区间内的最小值最大值等。 - 支持以$O(lgn)$的时间修改数组中的某一元素。
基础思想
本节按照求区间和需求解释如何构建并维护一棵线段树。其描述如下:
- 对于数组
a[0:n-1]
,树的根节点存储整个区间的和,即sum(a[0:n-1])
,等价于其维护区间[0:n-1]
; - 树的任意中间结点,若维护区间
[l:r]
,则其左子树维护[l:(l+r)/2]
,右子树维护区间[(l+r)/2+1:r]
; - 若
l==r
,则此为叶子节点,不再分裂。
这些树节点的存储可以借鉴二叉堆的思想,用一个数组t
来存各层的树结点:根结点在t[0]
,其左孩子位于t[1]
,右孩子位于t[2]
……以此类推,即t[i]
的左孩子为t[2i+1]
,右孩子为t[2i+2]
。
复杂度分析
树高
显然树的构建可以递归的完成,因为每次分裂均对区间长度做二分处理,因此树的高度在$lgn$这一量级;
数据更新
若需要更新数组a[i]=x
,则显而易见的做法是从树根递归向下直到对应a[i]
的叶子结点,将受影响的区间和进行修改,而每一层的所有区间是不相交的,仅会有一个区间包含a[i]
,因此这一过程时间复杂度$O(lgn)$。
求区间和
求区间a[l:r]
的和时,依然是从树根开始递归向下,只是这里可能会出现触发多次递归的情况:
- 若当前结点维护的区间
[tl:tr]
恰好等于[l:r]
,则直接返回其和; - 若当前区间
[tl:tr]
被[l:r]
的左半区间,或右半区间包含,则递归至对应子树; - 若当前区间与左右半区间均有交集,记中间点为
m
,则依次递归左子树查询区间[l:m]
,递归右子树查询区间[m+1,r]
,再求和。
我们证明上面的递归过程中访问的结点数量不超过$4lgn$,从而时间复杂度为$O(lgn)$。
实际上,我们只需要证明树的每一层中访问到的结点数量不超过4即可。通过归纳法证明,首先在第一层(树根)显然仅1个结点。其次,对于任意层:
- 若访问了2个结点,则下一层最多访问4个结点,因为每个结点最多触发两次递归;
- 若访问了3或4个结点,则这些结点构成的区间必然是相连的,中间的1个或2个结点必然恰好等于查询的区间,不会触发下一层的递归,仅有最左侧和最右侧的结点可能触发递归,而这也导致下一层不超过4个;
基础实现
对于上面提到的实现,长度$n$的数组,其产生的树高为$\lceil lgn \rceil+1$,因此数组大小至多$4n$。通过递归可以完成所有操作。
class SegmentTree {
vector<int> t;
int n;
int build(int idx, int l, int r, const vector<int>& a)
{
if (l == r)
return t[idx] = a[l];
int mid = (l + r) / 2;
return t[idx] = build(2 * idx + 1, l, mid, a) +
build(2 * idx + 2, mid + 1, r, a);
}
void update(int i, int val, int idx, int l, int r)
{
if (l == r) {
t[idx] = val;
return;
}
int mid = (l + r) / 2;
if (i <= mid)
update(i, val, 2 * idx + 1, l, mid);
else
update(i, val, 2 * idx + 2, mid + 1, r);
t[idx] = t[2 * idx + 1] + t[2 * idx + 2];
}
int sum(int l, int r, int idx, int tl, int tr)
{
if (l == tl && r == tr)
return t[idx];
int mid = (tl + tr) / 2;
if (r <= mid)
return sum(l, r, 2 * idx + 1, tl, mid);
if (l > mid)
return sum(l, r, 2 * idx + 2, mid + 1, tr);
return sum(l, mid, 2 * idx + 1, tl, mid) +
sum(mid + 1, r, 2 * idx + 2, mid + 1, tr);
}
public:
SegmentTree(const vector<int>& a)
: t(4 * a.size())
, n(a.size())
{
build(0, 0, a.size() - 1, a);
}
void update(int i, int val)
{
update(i, val, 0, 0, n - 1);
}
int sum(int l, int r)
{
return sum(l, r, 0, 0, n - 1);
}
};
更紧凑的实现
再看线段树的构造过程,我们可以确定其构造出的二叉树是一棵满二叉树(所有结点要么是叶子结点,要么有两个孩子),且所有叶子结点表示单个元素区间,也就是说有$n$个叶子节点。这说明该二叉树的结点数量为$2n-1$(可简单归纳证明)。
而我们之前对树的存储方式下,由于二叉树最后一层会产生空缺,而且不像二叉堆一样是连续的,导致我们不得不使用一个$4n$的数组来存储以应对最坏情况。
我们可以修改原先的存储方式,原先我们是按照层遍历的顺序存储对应位置的结点于数组中,现在考虑按照先序遍历的顺序进行存储。
考虑在索引idx
下的一个结点,其代表区间[l:r]
,我们用索引idx+1
来存储其左孩子,此时其左子树代表区间[l:mid]
,因此左子树占据的结点数量为$2(mid-l+1)-1$,据此我们可以计算右孩子的索引为$idx+2(mid-l+1)$。以此类推,则二叉树在数组中的存储是紧凑的,只需要$2n-1$长度的数组即可存储。
下面的代码仅将左右子树根结点在数组中的索引计算修改,即可达到优化效果。
class SegmentTree {
vector<int> t;
int n;
int build(int idx, int l, int r, const vector<int>& a)
{
if (l == r) {
return t[idx] = a[l];
}
int mid = (l + r) / 2;
return t[idx] = build(idx + 1, l, mid, a) +
build(idx + 2 * (mid - l + 1), mid + 1, r, a);
}
void update(int i, int val, int idx, int l, int r)
{
if (l == r) {
t[idx] = val;
return;
}
int mid = (l + r) / 2;
if (i <= mid)
update(i, val, idx + 1, l, mid);
else
update(i, val, idx + 2 * (mid - l + 1), mid + 1, r);
t[idx] = t[idx + 1] + t[idx + 2 * (mid - l + 1)];
}
int sum(int l, int r, int idx, int tl, int tr)
{
if (l == tl && r == tr)
return t[idx];
int mid = (tl + tr) / 2;
if (r <= mid)
return sum(l, r, idx + 1, tl, mid);
if (l > mid)
return sum(l, r, idx + 2 * (mid - tl + 1), mid + 1, tr);
return sum(l, mid, idx + 1, tl, mid) +
sum(mid + 1, r, idx + 2 * (mid - tl + 1), mid + 1, tr);
}
public:
SegmentTree(const vector<int>& a)
: t(2 * a.size() - 1)
, n(a.size())
{
build(0, 0, a.size() - 1, a);
}
void update(int i, int val)
{
update(i, val, 0, 0, n - 1);
}
int sum(int l, int r)
{
return sum(l, r, 0, 0, n - 1);
}
};
应用
求区间最值
这是最基本的应用,只要把上面的左右子树的merge过程修改为求子树区间的最值(而非求和)即可。
求区间最大值及其出现次数
在最大值多次出现的时候,我们还想计算其出现的次数。这个问题可以通过在结点处同时存储子树的最大值及出现次数,同时修改merge的过程来处理。
求区间内最大公约数或最小公倍数
这也是很简单的变体,只需修改merge过程为求子树的gcd或lcm即可。
求数组中出现的第k个零的位置
不止是零,计算第k
个任意数均可以通过修改merge过程,使树结点维护区间内目标元素出现的次数,这样即可根据当前子树下求解的排位k
递归计算。
int find_kth(int idx, int l, int r, int k) {
if (k > t[idx]) // 当前区间内零的数量不足k
return -1;
if (l == r)
return l; // 找到了
int mid = (l + r) / 2;
if (t[idx+1] >= k) // 第k个在左子树
return find_kth(idx+1, l, mid, k);
else // 第k个在右子树
return find_kth(idx+2*(mid-l+1), mid+1, r, k - t[idx+1]);
}
求大于给定值的最短前缀
给定一个正的数组a
,以及一个值k
,求最小的索引i
,使得前缀a[0:i]
的和大于k
。
同样地,这个问题可以通过维护区间和来处理,我们递归计算当前区间[l:r]
下第一个大于k
的前缀位置,若其左子树区间和大于k
,则递归进入左子树,否则进入右子树,且目标值减去左子树之和。
求区间内第一个大于k的元素
给定一个数组a
,求区间[l:r]
内第一个大于k
的元素a[i]
。
这个问题通过维护区间最大值来处理,考虑当前区间的左子树,若其最大值大于k,则递归进入左子树;否则递归进入右子树。
求最大的子区间和
给定一个数组a
,求区间[l:r]
内具有最大区间和的子区间a[i:j]
,其中l<=i<=j<=r
。
这个问题的求解略复杂,我们需要在树结点中存储如下四个信息:
- 区间和
- 区间的最大前缀和
- 区间的最大后缀和
- 最大子区间和
然后我们递归构建之,首先计算出左子树和右子树下的这组值,则当前位置的最大子区间和能够只有如下三个情况:
- 左半区间的最大子区间和
- 右半区间的最大子区间和
- 左半区间的最大后缀和加上右半区间的最大前缀和
构建完成后,即可通过递归方式求解指定区间内的最大子区间和,算法和求区间最值一致,只是在组合左右子树的返回值时以上面的思想来计算结果。