알고리즘/코드

세그먼트 트리 / Segment Tree / Segment Tree with Lazy Propagation C++ 코드

경우42 2024. 6. 6. 16:28
반응형
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

vector<long long> tree;
vector<long long> a;
int N, Q;

// 세그먼트 트리 구축
void build() {
    for (int i = 0; i < N; ++i) {
        tree[N + i] = a[i];
    }
    for (int i = N - 1; i > 0; --i) {
        tree[i] = tree[i << 1] + tree[i << 1 | 1];
    }
}

// 세그먼트 트리의 특정 위치 값 업데이트
void updateTree(int where, long long value) {
    where += N;
    tree[where] = value;

    while (where > 1) {
        where >>= 1;
        tree[where] = tree[where << 1] + tree[where << 1 | 1];
    }
}

// 세그먼트 트리 구간 합 쿼리
long long query(int left, int right) {
    long long sum = 0;
    left += N;
    right += N;

    while (left <= right) {
        if (left & 1) sum += tree[left++];
        if (!(right & 1)) sum += tree[right--];
        left >>= 1;
        right >>= 1;
    }
    return sum;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    cin >> N >> Q;
    a.resize(N);
    tree.resize(2 * N);

    for (int i = 0; i < N; ++i) {
        cin >> a[i];
    }

    build();

    for (int i = 0; i < Q; ++i) {
        int x, y, a_idx;
        long long b;
        cin >> x >> y >> a_idx >> b;
        if (x > y) swap(x, y);
        cout << query(x - 1, y - 1) << '\n';
        updateTree(a_idx - 1, b);
    }

    return 0;
}

 

Segment Tree  with Lazy Propagation

#include <bits/stdc++.h>
using namespace std;

static const int MAXN = 100000; // 필요에 따라 조정
long long arr[MAXN];          // 원본 배열
long long seg[4 * MAXN];      // 세그먼트 트리 배열
long long lazy[4 * MAXN];     // 레이지 배열(지연 업데이트 정보)
int N, Q;

// 트리 구축: O(N)
void buildTree(int idx, int start, int end) {
    if (start == end) {
        // 리프 노드: arr[start] 값으로 초기화
        seg[idx] = arr[start];
        return;
    }
    int mid = (start + end) / 2;
    buildTree(idx * 2,     start,     mid);
    buildTree(idx * 2 + 1, mid + 1,   end);

    // 부모 노드는 자식 노드 합으로 구성
    seg[idx] = seg[idx * 2] + seg[idx * 2 + 1];
}

// lazy 값을 실제 seg 트리에 반영하고,
// 자식에게 lazy 값을 넘겨주는 함수
void propagate(int idx, int start, int end) {
    // lazy[idx] != 0 이면, 아직 처리되지 않은 업데이트가 남아있다는 뜻
    if (lazy[idx] != 0) {
        // 현재 구간의 크기만큼 lazy[idx] 값을 더해줌
        seg[idx] += (end - start + 1) * lazy[idx];

        // 리프 노드가 아니라면, 자식 노드의 lazy 값에 누적
        if (start != end) {
            lazy[idx * 2]     += lazy[idx];
            lazy[idx * 2 + 1] += lazy[idx];
        }
        // 현재 노드의 lazy 정보는 반영이 끝났으므로 0으로 초기화
        lazy[idx] = 0;
    }
}

// 구간 [l, r] 에 val 을 더하는 업데이트
void updateRange(int idx, int start, int end, int l, int r, long long val) {
    // 우선 현재 노드에 혹시 남아있는 lazy 값이 있다면 반영
    propagate(idx, start, end);

    // [start, end]와 [l, r]가 전혀 겹치지 않는 경우
    if (end < l || r < start) {
        return;
    }

    // [start, end]가 [l, r]에 완전히 포함되는 경우 (전부 커버)
    if (l <= start && end <= r) {
        // lazy 배열에만 값을 더해 두고, propagate해서 반영
        lazy[idx] += val;
        propagate(idx, start, end); 
        return;
    }

    // 겹치는 구간이지만 완전히 포함되지 않는 경우: 자식에게 내려가서 처리
    int mid = (start + end) / 2;
    updateRange(idx * 2,     start,     mid,     l, r, val);
    updateRange(idx * 2 + 1, mid + 1,   end,     l, r, val);

    // 자식 노드의 업데이트가 끝나면, 부모 노드 값 갱신
    seg[idx] = seg[idx * 2] + seg[idx * 2 + 1];
}

// 구간 [l, r]의 합을 구하는 쿼리
long long queryRange(int idx, int start, int end, int l, int r) {
    // 우선 현재 노드의 lazy 값 반영
    propagate(idx, start, end);

    // [start, end]와 [l, r]가 겹치지 않으면 0 리턴
    if (end < l || r < start) {
        return 0;
    }

    // [start, end]가 [l, r]에 완전히 포함될 때
    if (l <= start && end <= r) {
        return seg[idx];
    }

    // 겹치는 구간
    int mid = (start + end) / 2;
    return queryRange(idx * 2,     start, mid,     l, r)
         + queryRange(idx * 2 + 1, mid+1, end,     l, r);
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> N >> Q;
    for (int i = 0; i < N; i++) {
        cin >> arr[i];
    }

    // 세그먼트 트리 초기화
    buildTree(1, 0, N - 1);

    while (Q--) {
        int t;
        cin >> t;
        if (t == 1) {
            // 예: "구간 [l, r] 에 val 더하기" 쿼리
            int l, r;
            long long val;
            cin >> l >> r >> val;
            // 문제에서 1-based 입력이면 0-based로 조정
            l--; 
            r--;
            if (l > r) swap(l, r); 
            updateRange(1, 0, N - 1, l, r, val);
        } 
        else {
            // 예: "구간 [l, r] 합 구하기" 쿼리
            int l, r;
            cin >> l >> r;
            l--;
            r--;
            if (l > r) swap(l, r);
            cout << queryRange(1, 0, N - 1, l, r) << "\n";
        }
    }

    return 0;
}
반응형