线段树(Segment Tree)是一种在数组之上维护的树结构(也叫statistic tree),它能够存储一个数组a,支持下面两个操作:

  1. 以$O(lgn)$的时间查询数组上的一个区间a[l:r]内的某种统计值,比如查询区间内的元素之和,或区间内的最小值最大值等。
  2. 支持以$O(lgn)$的时间修改数组中的某一元素。

基础思想

本节按照求区间和需求解释如何构建并维护一棵线段树。其描述如下:

  1. 对于数组a[0:n-1],树的根节点存储整个区间的和,即sum(a[0:n-1]),等价于其维护区间[0:n-1]
  2. 树的任意中间结点,若维护区间[l:r],则其左子树维护[l:(l+r)/2],右子树维护区间[(l+r)/2+1:r]
  3. l==r,则此为叶子节点,不再分裂。

这些树节点的存储可以借鉴二叉堆的思想,用一个数组t来存各层的树结点:根结点在t[0],其左孩子位于t[1],右孩子位于t[2]……以此类推,即t[i]的左孩子为t[2i+1],右孩子为t[2i+2]

复杂度分析

树高

显然树的构建可以递归的完成,因为每次分裂均对区间长度做二分处理,因此树的高度在$lgn$这一量级;

数据更新

若需要更新数组a[i]=x,则显而易见的做法是从树根递归向下直到对应a[i]的叶子结点,将受影响的区间和进行修改,而每一层的所有区间是不相交的,仅会有一个区间包含a[i],因此这一过程时间复杂度$O(lgn)$。

求区间和

求区间a[l:r]的和时,依然是从树根开始递归向下,只是这里可能会出现触发多次递归的情况:

  1. 若当前结点维护的区间[tl:tr]恰好等于[l:r],则直接返回其和;
  2. 若当前区间[tl:tr][l:r]的左半区间,或右半区间包含,则递归至对应子树;
  3. 若当前区间与左右半区间均有交集,记中间点为m,则依次递归左子树查询区间[l:m],递归右子树查询区间[m+1,r],再求和。

我们证明上面的递归过程中访问的结点数量不超过$4lgn$,从而时间复杂度为$O(lgn)$。

实际上,我们只需要证明树的每一层中访问到的结点数量不超过4即可。通过归纳法证明,首先在第一层(树根)显然仅1个结点。其次,对于任意层:

  1. 若访问了2个结点,则下一层最多访问4个结点,因为每个结点最多触发两次递归;
  2. 若访问了3或4个结点,则这些结点构成的区间必然是相连的,中间的1个或2个结点必然恰好等于查询的区间,不会触发下一层的递归,仅有最左侧和最右侧的结点可能触发递归,而这也导致下一层不超过4个;

基础实现

对于上面提到的实现,长度$n$的数组,其产生的树高为$\lceil lgn \rceil+1$,因此数组大小至多$4n$。通过递归可以完成所有操作。

class SegmentTree {
    vector<int> t;
    int n;

    int build(int idx, int l, int r, const vector<int>& a)
    {
        if (l == r)
            return t[idx] = a[l];
        int mid = (l + r) / 2;
        return t[idx] = build(2 * idx + 1, l, mid, a) + 
                        build(2 * idx + 2, mid + 1, r, a);
    }

    void update(int i, int val, int idx, int l, int r)
    {
        if (l == r) {
            t[idx] = val;
            return;
        }
        int mid = (l + r) / 2;
        if (i <= mid)
            update(i, val, 2 * idx + 1, l, mid);
        else
            update(i, val, 2 * idx + 2, mid + 1, r);
        t[idx] = t[2 * idx + 1] + t[2 * idx + 2];
    }
    int sum(int l, int r, int idx, int tl, int tr)
    {
        if (l == tl && r == tr)
            return t[idx];
        int mid = (tl + tr) / 2;
        if (r <= mid)
            return sum(l, r, 2 * idx + 1, tl, mid);
        if (l > mid)
            return sum(l, r, 2 * idx + 2, mid + 1, tr);
        return sum(l, mid, 2 * idx + 1, tl, mid) +
               sum(mid + 1, r, 2 * idx + 2, mid + 1, tr);
    }

public:
    SegmentTree(const vector<int>& a)
        : t(4 * a.size())
        , n(a.size())
    {
        build(0, 0, a.size() - 1, a);
    }

    void update(int i, int val)
    {
        update(i, val, 0, 0, n - 1);
    }

    int sum(int l, int r)
    {
        return sum(l, r, 0, 0, n - 1);
    }
};

更紧凑的实现

再看线段树的构造过程,我们可以确定其构造出的二叉树是一棵满二叉树(所有结点要么是叶子结点,要么有两个孩子),且所有叶子结点表示单个元素区间,也就是说有$n$个叶子节点。这说明该二叉树的结点数量为$2n-1$(可简单归纳证明)。

而我们之前对树的存储方式下,由于二叉树最后一层会产生空缺,而且不像二叉堆一样是连续的,导致我们不得不使用一个$4n$的数组来存储以应对最坏情况。

我们可以修改原先的存储方式,原先我们是按照层遍历的顺序存储对应位置的结点于数组中,现在考虑按照先序遍历的顺序进行存储。

考虑在索引idx下的一个结点,其代表区间[l:r],我们用索引idx+1来存储其左孩子,此时其左子树代表区间[l:mid],因此左子树占据的结点数量为$2(mid-l+1)-1$,据此我们可以计算右孩子的索引为$idx+2(mid-l+1)$。以此类推,则二叉树在数组中的存储是紧凑的,只需要$2n-1$长度的数组即可存储。

下面的代码仅将左右子树根结点在数组中的索引计算修改,即可达到优化效果。

class SegmentTree {
    vector<int> t;
    int n;

    int build(int idx, int l, int r, const vector<int>& a)
    {
        if (l == r) {
            return t[idx] = a[l];
        }
        int mid = (l + r) / 2;
        return t[idx] = build(idx + 1, l, mid, a) + 
                        build(idx + 2 * (mid - l + 1), mid + 1, r, a);
    }

    void update(int i, int val, int idx, int l, int r)
    {
        if (l == r) {
            t[idx] = val;
            return;
        }
        int mid = (l + r) / 2;
        if (i <= mid)
            update(i, val, idx + 1, l, mid);
        else
            update(i, val, idx + 2 * (mid - l + 1), mid + 1, r);
        t[idx] = t[idx + 1] + t[idx + 2 * (mid - l + 1)];
    }

    int sum(int l, int r, int idx, int tl, int tr)
    {
        if (l == tl && r == tr)
            return t[idx];
        int mid = (tl + tr) / 2;
        if (r <= mid)
            return sum(l, r, idx + 1, tl, mid);
        if (l > mid)
            return sum(l, r, idx + 2 * (mid - tl + 1), mid + 1, tr);
        return sum(l, mid, idx + 1, tl, mid) + 
               sum(mid + 1, r, idx + 2 * (mid - tl + 1), mid + 1, tr);
    }

public:
    SegmentTree(const vector<int>& a)
        : t(2 * a.size() - 1)
        , n(a.size())
    {
        build(0, 0, a.size() - 1, a);
    }

    void update(int i, int val)
    {
        update(i, val, 0, 0, n - 1);
    }

    int sum(int l, int r)
    {
        return sum(l, r, 0, 0, n - 1);
    }
};

应用

求区间最值

这是最基本的应用,只要把上面的左右子树的merge过程修改为求子树区间的最值(而非求和)即可。

求区间最大值及其出现次数

在最大值多次出现的时候,我们还想计算其出现的次数。这个问题可以通过在结点处同时存储子树的最大值及出现次数,同时修改merge的过程来处理。

求区间内最大公约数或最小公倍数

这也是很简单的变体,只需修改merge过程为求子树的gcd或lcm即可。

求数组中出现的第k个零的位置

不止是零,计算第k个任意数均可以通过修改merge过程,使树结点维护区间内目标元素出现的次数,这样即可根据当前子树下求解的排位k递归计算。

int find_kth(int idx, int l, int r, int k) {
    if (k > t[idx]) // 当前区间内零的数量不足k
        return -1;
    if (l == r)
        return l; // 找到了
    int mid = (l + r) / 2;
    if (t[idx+1] >= k) // 第k个在左子树
        return find_kth(idx+1, l, mid, k);
    else  // 第k个在右子树
        return find_kth(idx+2*(mid-l+1), mid+1, r, k - t[idx+1]);
}

求大于给定值的最短前缀

给定一个正的数组a,以及一个值k,求最小的索引i,使得前缀a[0:i]的和大于k

同样地,这个问题可以通过维护区间和来处理,我们递归计算当前区间[l:r]下第一个大于k的前缀位置,若其左子树区间和大于k,则递归进入左子树,否则进入右子树,且目标值减去左子树之和。

求区间内第一个大于k的元素

给定一个数组a,求区间[l:r]内第一个大于k的元素a[i]

这个问题通过维护区间最大值来处理,考虑当前区间的左子树,若其最大值大于k,则递归进入左子树;否则递归进入右子树。

求最大的子区间和

给定一个数组a,求区间[l:r]内具有最大区间和的子区间a[i:j],其中l<=i<=j<=r

这个问题的求解略复杂,我们需要在树结点中存储如下四个信息:

  1. 区间和
  2. 区间的最大前缀和
  3. 区间的最大后缀和
  4. 最大子区间和

然后我们递归构建之,首先计算出左子树和右子树下的这组值,则当前位置的最大子区间和能够只有如下三个情况:

  1. 左半区间的最大子区间和
  2. 右半区间的最大子区间和
  3. 左半区间的最大后缀和加上右半区间的最大前缀和

构建完成后,即可通过递归方式求解指定区间内的最大子区间和,算法和求区间最值一致,只是在组合左右子树的返回值时以上面的思想来计算结果。