プログラムのお勉強メモ

プログラムの勉強メモです. Python, Rust, など.

Rustでセグメント木を作ってみた

セグメント木

  • 競プロer御用達のデータ構造の1つ
    • 配列の範囲に対して値を代入したり, 範囲内の合計値や最大値/最小値等を 高速に 取得するためのデータ構造
    • AtCoder Beginner Contest 185 にてセグ木の問題が出題された
    • 悔しいのでRustで実装することとした

実装にあたり心がけたこと

  • せっかくなので本で見たジェネリクスを使う
  • 問題は範囲内の XOR だったが、計算式を差し替えられるように作りたい
  • 他の人の解答を パクる 参考にさせて頂きながら, 自分の好みの実装を理解しながら書く
  • 相変わらず実装にあたっては書籍*1, *2を大いに活用させて頂いた

実装

struct SegmentTree<F, T> {
    size: usize,
    tree: Vec<T>,
    element: T,
    eval: F,
}

impl<F: Fn(T, T) -> T, T: Copy + Eq + std::fmt::Debug> SegmentTree<F, T> {
    fn new(max: usize, element: T, eval: F) -> Self {
        // サイズを収まる範囲の 2^x乗 にする
        let size = max.next_power_of_two();
        Self {
            size,
            tree: vec![element; size * 2], // セグ木はその2倍のサイズ
            element,
            eval,
        }
    }

    // 開閉区間のの値を取得する
    // new した時のロジックで処理される
    fn get(&self, left: usize, right: usize) -> T {
        return self._get(left, right + 1, 1, 0, self.size);
    }

    fn _get(&self, left: usize, right: usize, now_pos: usize, l: usize, r: usize) -> T {
        // 捜索範囲を超えた場合 初期値 を返す
        if r <= left || right <= l {
            self.element
        // 探索終了条件
        // 二分探索して値を見つけた場合
        } else if left <= l && r <= right {
            self.tree[now_pos]
        // 探索が続く場合
        // 今のポジションから左(*2)と右(*2+1)に移動
        // 左に移動した場合、右端をずらす
        // 右に移動した場合、左端をずらす
        } else {
            (self.eval)(
                self._get(left, right, now_pos * 2, l, (l + r) / 2), // 左
                self._get(left, right, now_pos * 2 + 1, (l + r) / 2, r), //右
            )
        }
    }

    pub fn update(&mut self, index: usize, value: T) {
        let mut i = self.size + index;
        while i != 0 {
            let before = self.tree[i];
            let after = (self.eval)(before, value);

            // 更新しても変わらない場合その後も変わらない
            if before == after {
                break;
            }
            self.tree[i] = after;
            i /= 2;
        }
    }
}

impl<F, T: std::fmt::Debug> std::fmt::Debug for SegmentTree<F, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SegmentTree{:?}", self.tree)
    }
}

// 以下, AtCoder Beginner Contest 185-F を解くためのコード
fn main() {
    let (n, q) = input!(usize, usize);
    let mut seg_tree = SegmentTree::new(n, 0, |a, b| a ^ b);
    let a_vec = input!(isize; "vec");

    for (i, v) in a_vec.iter().enumerate() {
        seg_tree.update(i, *v);
    }

    //println!("{:?}", seg_tree);

    for _ in 0..q {
        let (t, x, y) = input!(usize, usize, isize);

        match t == 1 {
            true => seg_tree.update(x - 1, y),
            false => println!("{}", seg_tree.get(x - 1, (y - 1) as usize)),
        }
    }
}

苦労した点

  • next_power_of_two() 関数の存在をしれただけでも実施した甲斐があった
  • やっぱり所有権が難しい!全然コンパイル通らなかった!
  • ジェネリクス型もわかったようでわからない
  • Rust is very difficult.
  • セグメント木の範囲探索は二分探索法を利用
    • 下から2倍 or +1/-1 しながら挟み込む方法も見たが二分探索のほうが好みだった

参考書籍

  1. 実践Rustプログラミング入門
  2. 問題解決力を鍛える!アルゴリズムとデータ構造
  3. AtCoder Rust提出者の皆様

*1:実践Rustプログラミング入門

*2:問題解決力を鍛える!アルゴリズムとデータ構造