정화 코딩

[C++] 구간 합 구하기 (백준 2042번) - 세그먼트 트리 정리 본문

PS

[C++] 구간 합 구하기 (백준 2042번) - 세그먼트 트리 정리

jungh150c 2025. 4. 24. 11:14

https://www.acmicpc.net/problem/2042

세그먼트 트리란?

세그먼트 트리는 구간에 대한 정보를 트리 구조로 저장하여, 구간의 갱신과 조회를 효율적으로 처리할 수 있도록 설계된 자료구조이다.

세그먼트 트리의 원리

1번 노드는 1~8 구간의 정보를 담고 있고

2번 노드는 1~4 구간의 정보를, 3번 노드는 5~8 구간의 정보를 담고 있고

...

이런 구조로 되어 있다.

 

3~7 구간의 정보를 알고 싶다고 생각해보자. 

우선 루트 노드에서 시작해서 쭉 내려간다. 

- 내 구간이 목표 구간에 완벽히 포함되지 않으면 절반으로 쪼개서 자식 노드로 전달한다.

- 만약 완벽히 포함된다면 내가 가지고 있는 정보를 그대로 위로 다시 올리면 되고, 만약 전혀 겹치지 않는다면 항등원을 반환한다. 

세그먼트 트리 구현

아래 글은 일반적인 상황에 대해 설명이고, 코드는 구간 합 문제에 대한 코드이다.

어떤 정보를 저장하는 세그먼트 트리인지에 따라 트리의 연산 방식이 달라질 수 있다. 따라서 구조는 비슷하게 사용하되, 필요한 부분만 바꿔서 사용하면 된다. 

1. init 함수 - 트리 초기화

long long init(int idx, int l, int r) {
    if (l == r) return tree[idx] = arr[l];
    int m = (l + r) / 2;
    return tree[idx] = init(idx * 2, l, m) + init(idx * 2 + 1, m + 1, r);
}

입력 배열 arr을 기반으로 트리 배열 tree를 초기화하고, 각 노드에 구간 정보를 저장한다. 

2. update 함수 - 값 갱신

long long update(int idx, int l, int r, int target, long long val) {
    if (target < l || target > r) return tree[idx];
    if (l == r) return tree[idx] = val;
    int m = (l + r) / 2;
    return tree[idx] = update(idx * 2, l, m, target, val) + update(idx * 2 + 1, m + 1, r, target, val);
}

특정 인덱스 값을 변경했을 때, 트리의 값을 함께 갱신한다. 

- 해당 인덱스가 현재 노드의 구간에 포함되지 않으면 무시한다.

- 리프 노드에 도달하면 값을 바꾸고, 부모 노드는 자식 노드의 변경된 값을 반영하여 자신의 값을 갱신한다. 

3. query 함수 - 구간 정보 조회

long long query(int idx, int l, int r, int wl, int wr) {
    if (wr < l || wl > r) return 0;
    if (wl <= l && wr >= r) return tree[idx];
    int m = (l + r) / 2;
    return query(idx * 2, l, m, wl, wr) + query(idx * 2 + 1, m + 1, r, wl, wr);
}

주어진 범위 [wl, wr] 에 대한 구간 정보를 구한다.

- 현재 노드의 구간이 아예 겹치지 않으면 항등원을 반환한다.

- 현재 노드의 구간이 요청한 범위에 완전히 포함되면 그 값을 바로 반환한다.

- 일부만 겹치면 자식 노드로 분할하여 다시 요청을 보낸다. 

 

#include <iostream>
#include <vector>
using namespace std;

int n, m, k;
vector<long long> arr;
vector<long long> tree;

long long init(int idx, int l, int r) {
    if (l == r) return tree[idx] = arr[l];
    int m = (l + r) / 2;
    return tree[idx] = init(idx * 2, l, m) + init(idx * 2 + 1, m + 1, r);
}

long long update(int idx, int l, int r, int target, long long val) {
    if (target < l || target > r) return tree[idx];
    if (l == r) return tree[idx] = val;
    int m = (l + r) / 2;
    return tree[idx] = update(idx * 2, l, m, target, val) + update(idx * 2 + 1, m + 1, r, target, val);
}

long long query(int idx, int l, int r, int wl, int wr) {
    if (wr < l || wl > r) return 0;
    if (wl <= l && wr >= r) return tree[idx];
    int m = (l + r) / 2;
    return query(idx * 2, l, m, wl, wr) + query(idx * 2 + 1, m + 1, r, wl, wr);
}

long long update(int target, long long val) {
    return update(1, 1, n, target, val);
}

long long query(int wl, int wr) {
    return query(1, 1, n, wl, wr);
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m >> k;
    int tc = m + k;

    tree.assign(4 * n + 1, 0);
    arr.resize(n + 1);

    for (int i = 1; i < n + 1; i++) cin >> arr[i];

    init(1, 1, n);

    while (tc--) {
        long long a, b, c;
        cin >> a >> b >> c;

        if (a == 1) update(b, c);
        else if (a == 2) cout << query(b, c) << '\n';
    }
}

(AC)

 

Comments