You are given a 0-indexed array of positive integers `w`, where `w[i]` describes the weight of the `i`<sup>th</sup> index.

You need to implement the function `pickIndex()`, which randomly picks an index in the range `[0, w.length - 1]` (inclusive) and returns it. The probability of picking an index `i` is `w[i] / sum(w)`.

For example, if `w = [1, 3]`, the probability of picking index `0` is `1 / (1 + 3) = 0.25` (i.e., 25%), and the probability of picking index `1` is `3 / (1 + 3) = 0.75` (i.e., 75%).

Implement the `Solution` class with the following methods:

- `Solution(int[] w)` initializes the object with the array of weights `w`.
- `int pickIndex()` returns a randomly picked index based on the weights.

# Weighted Random Index Selection
This problem requires implementing a `pickIndex()` function that randomly selects an index from an array of weights, where the probability of selecting each index is proportional to its weight. We can solve this by using a prefix sum array and binary search.
## Naive Approach (Not Efficient)
A naive approach would involve generating a large array where each index `i` appears `w[i]` times. Then, we can randomly select an index from this large array. However, this is extremely inefficient due to the large space requirements when dealing with large weights.
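As a rough illustration of why the naive approach scales poorly, here is a minimal sketch of it. The class name `NaiveSolution` and the attribute `expanded` are hypothetical names chosen for this example; the expanded list needs `sum(w)` entries, which is what makes large weights expensive.

```python
import random


class NaiveSolution:
    """Illustrative only: stores index i exactly w[i] times."""

    def __init__(self, w: list[int]):
        # The expanded list has sum(w) entries, so memory grows with the weights.
        self.expanded = [i for i, weight in enumerate(w) for _ in range(weight)]

    def pickIndex(self) -> int:
        # Every slot is equally likely, so index i is chosen w[i] / sum(w) of the time.
        return random.choice(self.expanded)


naive = NaiveSolution([1, 3])
print(naive.expanded)  # [0, 1, 1, 1]
```

With a single weight of one billion, this list alone would hold one billion entries.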
## Optimal Approach: Prefix Sum and Binary Search
A better approach is to use a prefix sum array. The prefix sum array `prefix_sums` stores the cumulative sum of the weights up to and including each index. To pick an index, we generate a random integer between 1 and the total sum of weights (both inclusive). Then, we perform a binary search on the prefix sum array to find the first index whose prefix sum is greater than or equal to that number. Index `i` owns exactly `w[i]` of the possible values, so it is picked with probability `w[i] / sum(w)`.
### Code Implementation (Python)
```python
import random


class Solution:
    def __init__(self, w: list[int]):
        # prefix_sums[i] holds w[0] + ... + w[i]
        self.prefix_sums = []
        self.total_sum = 0
        for weight in w:
            self.total_sum += weight
            self.prefix_sums.append(self.total_sum)

    def pickIndex(self) -> int:
        # Uniform integer in [1, total_sum], both ends inclusive
        target = random.randint(1, self.total_sum)
        # Binary search for the first prefix sum >= target
        low, high = 0, len(self.prefix_sums) - 1
        while low <= high:
            mid = (low + high) // 2
            if self.prefix_sums[mid] < target:
                low = mid + 1
            else:
                high = mid - 1
        return low


# Example usage:
w = [1, 3]
solution = Solution(w)
print(solution.pickIndex())
print(solution.pickIndex())
print(solution.pickIndex())
```
Let's say `w = [1, 3]`. Then `prefix_sums` becomes `[1, 4]` and `total_sum` becomes 4.

Consider the weights as sections of a number line where the length of each section is proportional to the weight. The prefix sum essentially marks the end of each section.

    Weights:      [1, 3]
    Number line:  0----1----2----3----4
    Index 0:      |----|                  (weight 1)
    Index 1:           |--------------|   (weight 3)
    Prefix sums:  [1, 4]
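
To make the number line concrete, the following sketch lists which index each possible random value maps to for `w = [1, 3]`. It uses the standard library's `bisect.bisect_left` as a stand-in for the hand-written binary search (both return the first position whose prefix sum is at least the target); the variable name `mapping` is illustrative only.

```python
from bisect import bisect_left
from itertools import accumulate

w = [1, 3]
prefix_sums = list(accumulate(w))  # running totals: [1, 4]
total = prefix_sums[-1]

# For every value the random draw can take, find the index it selects.
mapping = {target: bisect_left(prefix_sums, target) for target in range(1, total + 1)}
print(mapping)  # {1: 0, 2: 1, 3: 1, 4: 1}
```

One of the four equally likely values selects index 0 and three select index 1, which matches the 25% and 75% probabilities from the problem statement.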

## Complexity Analysis

- The `__init__` method iterates through the input array `w` once to compute the prefix sums. Therefore, its runtime complexity is O(n), where n is the length of `w`.
- The `pickIndex` method performs a binary search on the prefix sums array. Binary search has a runtime complexity of O(log n).

Therefore, the overall runtime complexity is:

- `__init__`: O(n)
- `pickIndex`: O(log n)

For space:

- The `__init__` method creates a `prefix_sums` array of the same length as the input array `w`. Therefore, its space complexity is O(n), where n is the length of `w`.
- The `pickIndex` method uses a constant amount of extra space, so its space complexity is O(1).

Therefore, the overall space complexity is O(n).

## Edge Cases

- If `w` is empty, `__init__` completes with `total_sum` equal to 0, and the first call to `pickIndex` raises a `ValueError` because `random.randint(1, 0)` has an empty range. To handle this case, we could add a check at the beginning of the `__init__` method that raises an exception if the array is empty.
- … the `pickIndex` method will still work correctly because the random number is generated between 1 and the total sum, and binary search will find the appropriate index.
- In languages with fixed-width integers, `total_sum` could potentially exceed the maximum integer value. To handle this, we could use long integers or normalize the weights by dividing all weights by a common factor. Python integers have arbitrary precision, so the code here is not affected.
- … `pickIndex()` will always return 0.

The version below adds a check for empty input, rejects negative weights, and falls back to a uniform pick when every weight is zero:

```python
import random


class Solution:
    def __init__(self, w: list[int]):
        if not w:
            raise ValueError("Input array cannot be empty")
        self.prefix_sums = []
        self.total_sum = 0
        for weight in w:
            if weight < 0:  # Weights should be non-negative
                raise ValueError("Weights must be non-negative")
            self.total_sum += weight
            self.prefix_sums.append(self.total_sum)

    def pickIndex(self) -> int:
        if self.total_sum == 0:  # If all weights are zero
            return random.randint(0, len(self.prefix_sums) - 1)
        target = random.randint(1, self.total_sum)
        # Binary search for the first prefix sum >= target
        low, high = 0, len(self.prefix_sums) - 1
        while low <= high:
            mid = (low + high) // 2
            if self.prefix_sums[mid] < target:
                low = mid + 1
            else:
                high = mid - 1
        return low
```
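
One way to sanity-check a weighted picker is to draw many samples and compare the observed frequencies with `w[i] / sum(w)`. The sketch below is an illustration under stated assumptions, not part of the solution: the helper `pick_index`, the seed `0`, and the trial count are hypothetical choices, and `bisect.bisect_left` stands in for the hand-written binary search.

```python
import random
from bisect import bisect_left
from collections import Counter
from itertools import accumulate


def pick_index(prefix_sums: list[int], rng: random.Random) -> int:
    """Hypothetical helper: same selection rule, written with bisect_left."""
    target = rng.randint(1, prefix_sums[-1])  # inclusive on both ends
    return bisect_left(prefix_sums, target)   # first prefix sum >= target


w = [1, 3]
prefix_sums = list(accumulate(w))
rng = random.Random(0)  # fixed seed (an arbitrary choice) for repeatable runs
trials = 100_000

counts = Counter(pick_index(prefix_sums, rng) for _ in range(trials))
for i, weight in enumerate(w):
    expected = weight / sum(w)
    observed = counts[i] / trials
    print(f"index {i}: expected {expected:.3f}, observed {observed:.3f}")
```

The observed frequencies should land close to 0.25 and 0.75, with small deviations that shrink as the number of trials grows.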