// Random Pick With Weight

// https://leetcode.com/problems/random-pick-with-weight/

/*
You are given a 0-indexed array of positive integers w where w[i] describes the weight of the ith index.
You need to implement the function pickIndex(), which randomly picks an index in the range [0, w.length - 1] (inclusive) and returns it. The probability of picking an index i is w[i] / sum(w).
For example, if w = [1, 3], the probability of picking index 0 is 1 / (1 + 3) = 0.25 (i.e., 25%), and the probability of picking index 1 is 3 / (1 + 3) = 0.75 (i.e., 75%).

Ex1:
Input
["Solution","pickIndex"]
[[[1]],[]]
Output
[null,0]

Explanation
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. The only option is to return 0 since there is only one element in w.

Ex2:
Input
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output
[null,1,1,1,1,0]

Explanation
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // return 1. It is returning the second element (index = 1) that has a probability of 3/4.
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 0. It is returning the first element (index = 0) that has a probability of 1/4.

Since this is a randomization problem, multiple answers are allowed.
All of the following outputs can be considered correct:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
......
and so on.
*/

#include <iostream>
#include <vector>
#include <random>
#include <algorithm>
#include <numeric>

using namespace std;

/**
 * Weighted random index picker (with replacement).
 *
 * The constructor stores the running (prefix) sums of the weights. pickIndex()
 * draws a uniform integer r in [0, total) and returns the first index whose
 * prefix sum is strictly greater than r. Index i therefore owns the half-open
 * range [prefix[i-1], prefix[i]) of width w[i], so P(i) = w[i] / sum(w).
 *
 * Assumes weights are non-negative (the problem guarantees positive weights);
 * negative weights would make the prefix sums unsorted and break the search.
 *
 * Construction: O(n) time, O(n) space. pickIndex(): O(log n) time.
 */
class Solution {
public:
    /// @param w weights; w[i] is the relative weight of index i.
    Solution(const std::vector<int>& w) {
        prefix_sum.reserve(w.size());
        long long running = 0;  // 64-bit so large weight totals cannot overflow
        for (int weight : w) {
            running += weight;
            prefix_sum.push_back(running);
        }
        // Guard: back() on an empty vector is undefined behaviour.
        total_sum = prefix_sum.empty() ? 0 : prefix_sum.back();
    }

    /// @return an index in [0, w.size()) drawn with probability w[i] / sum(w),
    ///         or -1 when there is nothing to pick (empty w or all-zero weights).
    // Time: O(logn), Space: O(1)
    int pickIndex() {
        if (total_sum <= 0) {
            return -1;
        }
        // Uniform over [0, total_sum - 1]. Unlike rand() % total_sum this has no
        // modulo bias and is not capped at RAND_MAX (32767 on some platforms).
        std::uniform_int_distribution<long long> dist(0, total_sum - 1);
        const long long r = dist(rng);
        // The first prefix sum strictly greater than r marks the bucket holding r.
        const auto it = std::upper_bound(prefix_sum.begin(), prefix_sum.end(), r);
        return static_cast<int>(it - prefix_sum.begin());
    }

private:
    std::vector<long long> prefix_sum;
    long long total_sum = 0;
    std::mt19937 rng{std::random_device{}()};
};

// Follow up 2: For each index we pick, we delete it from the array.
// Time: O(logn), Space: O(n)
// Segment Tree

/**
 * Weighted random picker WITHOUT replacement, backed by a sum segment tree.
 *
 * The tree is laid out as follows:
 *   - segTree[node] holds the total weight of arr[start..end] for the range
 *     that node covers.
 *   - The root is node 1, covering [0, n-1].
 *   - The children of node are 2*node and 2*node+1.
 *
 * pickIndex() works in three steps:
 *   1. Draw r uniformly from [0, total).
 *   2. Descend to the leaf whose cumulative-weight interval contains r.
 *   3. Zero that leaf's weight so the index can never be picked again.
 * Once every weight is zero (or w was empty), pickIndex() returns -1.
 *
 * Assumes non-negative weights and that sum(w) fits in an int.
 * Construction: O(n). pickIndex(): O(log n). Space: O(n).
 */
class Solution2 {
public:
    Solution2(const std::vector<int>& w) {
        this->arr = w;
        const int n = static_cast<int>(w.size());
        segTree.assign(4 * n, 0);
        // build(1, 0, -1) would index past the (empty) tree.
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    /// Fills segTree[node] with the sum of arr[start..end], recursively.
    void build(int node, int start, int end) {
        if (start == end) {
            segTree[node] = arr[start];
            return;
        }
        const int mid = start + (end - start) / 2;
        build(2 * node, start, mid);
        build(2 * node + 1, mid + 1, end);
        segTree[node] = segTree[2 * node] + segTree[2 * node + 1];
    }

    /// Adds `val` to arr[idx] and to every tree node whose range contains idx.
    void update(int node, int start, int end, int idx, int val) {
        if (start == end) {
            arr[idx] += val;
            segTree[node] += val;
            return;
        }
        const int mid = start + (end - start) / 2;
        if (idx <= mid)
            update(2 * node, start, mid, idx, val);
        else
            update(2 * node + 1, mid + 1, end, idx, val);
        segTree[node] = segTree[2 * node] + segTree[2 * node + 1];
    }

    /// Finds the leaf (array index) whose cumulative-weight interval contains
    /// `val`, searching the subtree `node` that covers arr[start..end], and
    /// stores it in `res`.
    ///
    /// Requires 0 <= val < segTree[node]. If that does not hold, `res` is left
    /// unchanged.
    ///
    /// `val` is consumed: the weight of each left subtree skipped on the way
    /// down is subtracted from it, so the result is always a leaf of non-zero
    /// weight.
    void query(int node, int start, int end, int& val, int& res) {
        if (start > end || node < 1 ||
            static_cast<std::size_t>(node) >= segTree.size() ||
            val < 0 || val >= segTree[node]) {
            return;
        }
        while (start < end) {
            const int mid = start + (end - start) / 2;
            const int left = 2 * node;
            if (val < segTree[left]) {
                node = left;
                end = mid;
            } else {
                // Skip the whole left subtree; the invariant val < segTree[node]
                // still holds for the right child.
                val -= segTree[left];
                node = left + 1;
                start = mid + 1;
            }
        }
        res = start;
    }

    /// @return a weighted-random index that has not been picked before, or -1
    ///         when no weight remains.
    int pickIndex() {
        if (arr.empty() || segTree[1] <= 0) {
            return -1;
        }
        // Uniform over [0, total - 1], without rand()'s modulo bias or RAND_MAX cap.
        std::uniform_int_distribution<int> dist(0, segTree[1] - 1);
        int random_num = dist(rng);

        const int last = static_cast<int>(arr.size()) - 1;
        int idx = -1;
        query(1, 0, last, random_num, idx);
        if (idx < 0) {
            return -1;  // defensive: unreachable while the tree invariant holds
        }
        // Remove the picked index by driving its weight to zero.
        update(1, 0, last, idx, -arr[idx]);
        return idx;
    }

    /// Debug helper: dumps the raw tree storage and the current weights.
    void print() const {
        std::cout << "segment tree" << std::endl;
        for (int v : segTree)
            std::cout << v << " ";
        std::cout << std::endl;

        std::cout << "array " << std::endl;
        for (int v : arr)
            std::cout << v << " ";
        std::cout << std::endl;
    }

    std::vector<int> segTree;
    std::vector<int> arr;

private:
    std::mt19937 rng{std::random_device{}()};
};

int main() {
    vector<int> w = {1, 3, 5, 7, 9, 11};
    // Solution solution(w);
    // cout << solution.pickIndex() << endl;
    // cout << solution.pickIndex() << endl;
    // cout << solution.pickIndex() << endl;
    // cout << solution.pickIndex() << endl;
    // cout << solution.pickIndex() << endl;
    // cout << solution.pickIndex() << endl;

    Solution2 segSolution(w);
    cout << segSolution.pickIndex() << endl;
    cout << segSolution.pickIndex() << endl;
    cout << segSolution.pickIndex() << endl;
    cout << segSolution.pickIndex() << endl;
    cout << segSolution.pickIndex() << endl;
    cout << segSolution.pickIndex() << endl;
    // cout << segSolution.pickIndex() << endl;
    return 0;
}```

// Last updated