线段树


线段树模板


核心结构与变量定义

通常用一个数组来模拟这棵树。如果一个父节点的数组下标是 pp,那么它的左子节点下标是 2p2p,右子节点是 2p+12p+1。这要求我们从下标 1 开始存储 因此,如果原始数组的大小是 NN,为了安全地存储整棵树,我们通常会开一个大小为 4N4N 的数组

tree[p]tree[p]:存储节点 pp 所代表的区间的总和

lazy[p]lazy[p]:存储节点 pp 的懒惰标记。标记的含义是“这个区间的所有元素都需要加上 这个值,但我们暂时还没有把它传递给子节点”

aa: 原始输入数组


    long long a[100005];
    long long tree[100005 * 4];
    long long lazy[100005 * 4];
    int n;

push_up (信息上传)

这个函数非常简单,它的作用是用子节点的信息更新父节点。对于区间求和来说,父节点的和就等于它两个子节点的和 这个操作在建树和更新的递归回溯过程中会被频繁调用

功能:用左右子节点的值更新父节点

pp: 当前节点在 treetree 数组中的索引

void push_up(int p)
{
    tree[p] = tree[p * 2] + tree[p * 2 + 1];
}

build (建树)

建树是一个递归的过程,将原始数组 aa 的信息构建到 treetree 数组中

  • 递归终点:当区间 [l,r][l, r]l=rl = r 时,说明到达了叶子节点,它代表原始数组中的单个元素。
  • 递归过程: 找到当前区间 [l,r][l, r] 的中点 mid。 递归地为左子区间 [l,mid][l, mid] 和右子区间 [mid+1,r][mid+1, r] 建树 建完子树后,用 push_up 函数更新当前节点的值。

功能:构建线段树

pp: 当前节点在 treetree 数组中的索引

l,rl, r: 当前节点所代表的区间 [l,r][l, r]

调用入口: build(1, 1, n);

void build(int p, int l, int r)
{
    if (l == r)
    {
        // 到达叶子节点,直接赋值
        tree[p] = a[l];
        return;
    }
    int mid = l + (r - l) / 2;    // 计算中点,防止整数溢出
    build(p * 2, l, mid);         // 递归构建左子树
    build(p * 2 + 1, mid + 1, r); // 递归构建右子树
    push_up(p);                   // 用子节点信息更新当前节点
}

push_down (懒惰标记下传)

当我们需要更新或查询一个经过带有懒惰标记的节点的子节点时,我们必须先将这个标记传递下去

pp: 当前节点在 treetree 数组中的索引

l,rl, r : 当前节点所代表的区间 [l,r][l, r]

功能:将父节点 pp 的懒惰标记 lazy[p]lazy[p] 应用到它的两个子节点上

步骤:

检查当前节点 pp 是否有懒惰标记。如果没有,直接返回

更新左子节点 2p2ptreetree 值和 lazylazy 标记。treetree 值增加的大小是 lazy[p]lazy[p] 乘以左子区间的长度

更新右子节点 2p+12p+1treetree 值和 lazylazy 标记

清除当前节点 pp 的懒惰标记(因为它已经成功“下传”了)

void push_down(int p, int l, int r)
{
    if (lazy[p] != 0)
    { // 如果有标记
        int mid = l + (r - l) / 2;
        int len_left = mid - l + 1; // 左子区间的长度
        int len_right = r - mid;    // 右子区间的长度

        // 更新左子节点的 tree 值和 lazy 标记
        tree[p * 2] += lazy[p] * len_left;
        lazy[p * 2] += lazy[p];

        // 更新右子节点的 tree 值和 lazy 标记
        tree[p * 2 + 1] += lazy[p] * len_right;
        lazy[p * 2 + 1] += lazy[p];

        // 清除当前节点的标记
        lazy[p] = 0;
    }
}

update (区间修改)

[updatel,updater][update_l, update_r] 区间内的每个数都加上 valval 递归逻辑:

  1. 完全包含:如果当前节点代表的区间 [l,r][l, r] 被要修改的区间 [updatel,updater][update_l, update_r] 完全包含,我们直接更新当前节点的 treetree 值,并给它打上懒惰标记 lazylazy,然后返回。我们不需要再往下递归,这就是“懒惰”的体现
  2. 部分相交:如果当前区间与目标区间有交集,但不是完全包含: 先 push_down,确保子节点的状态是最新的。 递归地到可能相交的左子节点或右子节点去执行 update 子节点更新完毕后,push_up 更新当前节点
  3. 完全不相交:直接返回 调用入口: update(1, 1, n, start, end, value)
void update(int p, int l, int r, int update_l, int update_r, long long val)
{
// 情况 3: 完全不相交
// 如果当前区间与修改区间无交集,直接返回,什么都不做。
    if (r < update_l || l > update_r)
    {
        return;
    }

// 情况 1: 完全包含
// 如果当前区间被修改区间完全包含,则更新当前节点并打上懒惰标记。
    if (update_l <= l && r <= update_r)
    {
        tree[p] += val * (r - l + 1);
        lazy[p] += val;
        return;
    }

// 情况 2: 部分相交
// 先下传懒惰标记,然后递归到子节点进行更新。
    push_down(p, l, r);

    int mid = l + (r - l) / 2;
    update(p * 2, l, mid, update_l, update_r, val);
    update(p * 2 + 1, mid + 1, r, update_l, update_r, val);

// 不要忘记用更新后的子节点信息来更新当前节点。
    push_up(p);
}

query (区间查询)

查询与更新的逻辑非常相似 功能:区间查询,返回 [query_l, query_r] 的和 递归逻辑:

  1. 完全包含:如果当前节点代表的区间 [l, r] 被查询区间 [query_l, query_r] 完全包含,直接返回当前节点的 tree 值
  2. 部分相交: 先 push_down,确保子节点的值是最新的 递归地到相交的左、右子节点去查询,并将结果累加起来
  3. 完全不相交:返回一个不影响结果的值(对于求和是 00,求最大值是 -\infty ) 调用入口: query(1, 1, n, start, end);
long long query(int p, int l, int r, int query_l, int query_r)
{
	// 情况 2: 完全不相交
	// 如果当前区间 [l, r] 与查询区间 [query_l, query_r] 没有任何交集,
	// 那么当前子树对结果没有任何贡献,返回单位元 0。
	if (r < query_l || l > query_r)
		return 0; // 对于求和问题,不影响结果的值是 0

	// 情况 1: 完全包含
	// 如果当前区间被查询区间完全包含,直接返回当前节点的值。
	if (query_l <= l && r <= query_r)
		return tree[p];

	// 情况 3: 部分相交
	// 必须先下传懒惰标记,确保子节点的信息是正确的。
	push_down(p, l, r);

	int mid = l + (r - l) / 2;
	long long sum = 0;

	// 递归查询左右子节点,并将结果累加。
	// 注意:这里不再需要 if 判断,因为上面“完全不相交”的检查会处理好一切。
	// 如果查询区间只与一个子区间相交,另一个子区间的递归调用会因为不相交而立刻返回 0。
	sum += query(p * 2, l, mid, query_l, query_r);
	sum += query(p * 2 + 1, mid + 1, r, query_l, query_r);
	return sum;
}

线段树模板题


题目描述

如题,已知一个数列 {ai}\{a_i\},你需要进行下面两种操作:

  1. 将某区间每一个数加上 kk
  2. 求出某区间每一个数的和。

输入格式

第一行包含两个整数 n,mn, m,分别表示该数列数字的个数和操作的总个数。

第二行包含 nn 个用空格分隔的整数 aia_i,其中第 ii 个数字表示数列第 ii 项的初始值。

接下来 mm 行每行包含 3344 个整数,表示一个操作,具体如下:

  1. 1 x y k:将区间 [x,y][x, y] 内每个数加上 kk
  2. 2 x y:输出区间 [x,y][x, y] 内每个数的和。

输出格式

输出包含若干行整数,即为所有操作 2 的结果。

输入输出样例 #1

输入 #1
5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
输出 #1
11
8
20

说明/提示

对于 15%15\% 的数据:n8n \le 8m10m \le 10
对于 35%35\% 的数据:n103n \le {10}^3m104m \le {10}^4
对于 100%100\% 的数据:1n,m1051 \le n, m \le {10}^5ai,ka_i,k 为正数,且任意时刻数列的和不超过 2×10182\times 10^{18}

#include <bits/stdc++.h>
const int maxn = 1e5 + 5;
long long n, a[maxn], w[maxn << 2], tag[maxn << 2];
void build(long long l, long long r, long long u)
{
    if (l == r)
    {
        w[u] = a[l];
        return;
    }
    long long mid = l + ((r - l) >> 1);
    build(l, mid, u << 1);
    build(mid + 1, r, (u << 1) | 1);
    w[u] = w[u << 1] + w[(u << 1) | 1];
}
void update(long long l, long long r, long long c, long long L, long long R, long long u)
{
    if (l <= L && R <= r)
    {
        w[u] += (R - L + 1) * c;
        tag[u] += c;
        return;
    }
    long long mid = L + ((R - L) >> 1);
    if (tag[u])
    {
        w[u << 1] += tag[u] * (mid - L + 1);
        w[(u << 1) | 1] += tag[u] * (R - mid);
        tag[u << 1] += tag[u], tag[(u << 1) | 1] += tag[u];
    }
    tag[u] = 0;
    if (l <= mid)
        update(l, r, c, L, mid, u << 1);
    if (r > mid)
        update(l, r, c, mid + 1, R, (u << 1) | 1);
    w[u] = w[u << 1] + w[(u << 1) | 1];
}
long long getsum(long long l, long long r, long long L, long long R, long long u)
{
    if (l <= L && R <= r)
        return w[u];
    long long mid = L + ((R - L) >> 1);
    if (tag[u])
    {
        w[u << 1] += tag[u] * (mid - L + 1);
        w[(u << 1) | 1] += tag[u] * (R - mid);
        tag[u << 1] += tag[u];
        tag[(u << 1) | 1] += tag[u];
    }
    tag[u] = 0;
    long long sum = 0;
    if (l <= mid)
        sum = getsum(l, r, L, mid, u << 1);
    if (r > mid)
        sum += getsum(l, r, mid + 1, R, (u << 1) | 1);
    return sum;
}
int main()
{
    long long q, op, x, y, k;
    scanf("%lld%lld", &n, &q);
    for (long long i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    build(1, n, 1);
    while (q--)
    {
        scanf("%lld%lld%lld", &op, &x, &y);
        if (op == 1)
        {
            scanf("%lld", &k);
            update(x, y, k, 1, n, 1);
        }
        else
            printf("%lld\n", getsum(x, y, 1, n, 1));
    }
    return 0;
}