抽象化非再帰セグ木のシンプルな実装(C++)

この記事は

  • 再帰セグ木、再帰セグ木よりもシンプルに書けるよ
  • 抽象化、怖くないよ

という記事です。セグ木の初歩についてはあまり触れていません。
ですので、まず最初に、私がセグ木の勉強をした際に大変参考になった記事・スライド・動画をご紹介します。まだセグ木に関して何も知らないという方は、それらで勉強してから改めてこの記事を読んでいただけると嬉しいです。

セグ木入門資料

  • ふるやんさんのセグメントツリー入門
    会話形式でするするっと頭に入ってきます。セグ木の初めの一歩におすすめです。 ただし、ノードのindexの振り方は本記事と異なり、また再帰による実装なので、本記事を読む際に混乱しないよう注意してください。

  • えびちゃんさんのスライド
    https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
    再帰セグ木について詳しく書かれています。図が綺麗で、とても分かりやすいです。私の本記事を読むに当たってはこのスライドの入門編までの理解で十分ですが、セグ木に慣れてきた頃に応用編を読むとより理解が深まります。

  • かつっぱさんの「Segment木ってなに〜?」シリーズ
    https://www.youtube.com/watch?v=LjhVy1ZJTMc
    百聞は一見に如かず派の人におすすめです。セグ木の種類から仕組み、実装まで丁寧に解説されています。 私はこの動画で、各種セグ木の違いについて初めて理解しました。

とりあえず非再帰セグ木のコードを見てみよう

onlinejudge.u-aizu.ac.jp
まずはRMQ(Range Minimum Query)を解きます。

#include <iostream>
#include <vector>
#include <climits>

using namespace std;

struct segment_tree {
    int n;
    vector<int> node;
    segment_tree(int n) : n(n), node(n<<1, INT_MAX) {}
    void set(int i, int x) {
        node[i += n] = x;
        while (i >>= 1) node[i] = min(node[i<<1|0], node[i<<1|1]);
    }
    int fold(int l, int r) {
        int res = INT_MAX;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = min(res, node[l++]);
            if (r & 1) res = min(node[--r], res);
        }
        return res;
    }
};

int main() {
    int n, q;
    cin >> n >> q;
    segment_tree seg(n);
    while (q--) {
        int com, x, y;
        cin >> com >> x >> y;
        if (com) cout << seg.fold(x, y+1) << '\n';
        else seg.set(x, y);
    }
}

解けました。
セグ木部分についてはなんとたったの17行です。こんなにも簡単に書けちゃうんですね。

もう一つ解きましょう。次はこちら。
onlinejudge.u-aizu.ac.jp
RSQ(Range Sum Query)です。

#include <iostream>
#include <vector>

using namespace std;

struct segment_tree {
    int n;
    vector<int> node;
    segment_tree(int n) : n(n), node(n<<1, 0) {} // 初期値が0になりました
    // i番目の要素そのものにアクセスできる機能をつけました
    int operator[](int i) { return node[i + n]; } 
    void set(int i, int x) {
        node[i += n] = x;
        while (i >>= 1) node[i] = node[i<<1|0] + node[i<<1|1]; // 和になりました
    }
    int fold(int l, int r) {
        int res = 0; // 初期値が0になりました
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = res + node[l++]; // 和になりました
            if (r & 1) res = node[--r] + res; // 和になりました
        }
        return res;
    }
};

int main() {
    int n, q;
    cin >> n >> q;
    segment_tree seg(n);
    while (q--) {
        int com, x, y;
        cin >> com >> x >> y;
        x--;
        if (com) cout << seg.fold(x, y) << '\n';
        else seg.set(x, seg[x] + y); // 元々の値にyを足した値で更新します
    }
}

解けました。
(これはどうでもよいのですが、なぜRMQは0-indexedで、RSQは1-indexedなのでしょうか…)
RMQから変化した部分にコメントを付けましたが、全体的なコードの見た目はほとんど変わらないことが分かるかと思います。

抽象化しよう

ただ、クエリの種類に応じてセグ木の中身をちまちま書き換えるのは面倒ではありませんか? そこで有用なのが「抽象化」です。

「抽象化」とは「具体性を捨て、共通の概念を抜き出すこと」だと私は理解しています。

上の2つの問題を例にとって、セグ木の抽象化について考えましょう。
RMQは「区間の最小値を求める」クエリ
RSQは「区間の和を求める」クエリでしたが、
区間の〇〇を求める」という部分は共通しています。
まさにこの共通部分の抜き出しこそ、抽象化です。

それでは、

  1. 〇〇に相当する、区間の結合を行う演算
  2. minでいえば \infty (RMQのコードではINT_MAX)、加法でいえば 0 のような、
    他のどんな値とその演算を行っても、相手を変化させないような値(数学の言葉でいうと「単位元」というもの)

の2つをセグ木に渡してあげることで、「区間の〇〇を求める」ことができる汎用的な設計のセグ木へと進化させましょう。

上記の私なりの表現は、数学的な厳密性には欠けた記述ですので、数学的にどのような性質が成り立っているものならばセグ木に載るのかということに関しては、beetさんのブログをお読みください。
セグメント木について - beet's soil

また、演算の渡し方には以下のnoshi91さんのブログに書かれている通り様々な設計がありますが、今回は私の主観で一番理解しやすいと思われる、「A: std::function で関数オブジェクトを持つ」書き方で実装しました。ここは個人の趣味・趣向に応じて適宜改変してください。
代数的構造を乗せるデータ構造の設計について - noshi91のメモ

それでは、抽象化セグ木で、最初に解いたRMQを改めて解き直したコードがこちら。

#include <iostream>
#include <vector>
#include <climits>
#include <functional>

using namespace std;

template<typename T>
struct segment_tree {
    using F = function<T(T, T)>; // 型Tの値を2つ受け取って型Tの値を返すような関数型

    int n;
    vector<T> node;
    F combine; // 区間の結合を行う演算
    T identity; // 単位元

    segment_tree(int n, F combine, T identity)
        : n(n), node(n<<1, identity), combine(combine), identity(identity) {}

    T operator[](int i) { return node[i + n]; }

    void set(int i, T x) {
        node[i += n] = x;
        // combine関数により、2つの子ノードを結合した結果を親ノードに記録
        while (i >>= 1) node[i] = combine(node[i<<1|0], node[i<<1|1]); 

    }

    T fold(int l, int r) {
        T res = identity; // 初期値は単位元
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            // 区間内のものを左右から結合
            if (l & 1) res = combine(res, node[l++]); 
            if (r & 1) res = combine(node[--r], res);
        }
        return res;
    }
};


int main() {
    int n, q;
    cin >> n >> q;
    // ラムダ式でcombine関数を定義
    auto combine = [](int a, int b) { return min(a, b); };
    segment_tree<int> seg(n, combine, INT_MAX); // 要素数、演算、単位元をセグ木に渡す
    while (q--) {
        int com, x, y;
        cin >> com >> x >> y;
        if (com) cout << seg.fold(x, y+1) << '\n';
        else seg.set(x, y);
    }
}

RSQを解いたコードがこちら。
(セグ木部分は上と全く同じなので省略)

int main() {
    int n, q;
    cin >> n >> q;
    auto combine = [](int a, int b) { return a + b; }; // 変わったのはここと
    segment_tree<int> seg(n, combine, 0); // ここくらい
    while (q--) {
        int com, x, y;
        cin >> com >> x >> y;
        x--;
        if (com) cout << seg.fold(x, y) << '\n';
        else seg.set(x, seg[x] + y);
    }
}

抽象化をすることによって、演算が変わってもセグ木内部のコードを書き換える必要がなくなります。
抽象化、美しくありませんか?私は魅了されてしまいました。

まだセグ木の抽象化をしたことのない方、なんとなく敬遠されていた方が、
この記事を読んで「思っていたより簡単だな」と感じていただければ嬉しいです。

最後の抽象化非再帰セグ木をそのまま使用されてももちろん構わないのですが、私個人としては未完成なセグ木だと考えています。

読者への課題として

  • \Theta(N) でのセグ木の初期化
  • foldで可換則を要求しない
  • O(\log(N)) のセグ木上の二分探索を実装

などをあえて書かずに残しておきました。ぜひチャレンジしてみてくださいね。