Design a data structure to manage a set of intervals and efficiently count the number of integers covered by these intervals. Implement the CountIntervals class with the following methods:
CountIntervals(): Initializes an empty set of intervals.
void add(int left, int right): Adds the interval [left, right] to the set.
int count(): Returns the number of integers present in at least one interval.
Note: An interval [left, right] includes all integers x where left <= x <= right.
Example:
CountIntervals countIntervals = new CountIntervals();
countIntervals.add(2, 3); // Adds interval [2, 3]
countIntervals.add(7, 10); // Adds interval [7, 10]
countIntervals.count(); // Returns 6 (integers 2, 3, 7, 8, 9, 10)
countIntervals.add(5, 8); // Adds interval [5, 8]
countIntervals.count(); // Returns 8 (integers 2, 3, 5, 6, 7, 8, 9, 10)
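As a sanity check on the expected outputs above, the sketch below uses a hypothetical helper, covered_count (not part of the required CountIntervals API), that computes the same answer offline with a different technique: sort the intervals and merge overlapping or adjacent ones, then add up the merged lengths.

```python
def covered_count(intervals):
    """Count integers covered by at least one closed interval [l, r].

    Sorts the intervals by start and merges overlapping or adjacent ones,
    so the cost is O(n log n) no matter how long the intervals are.
    """
    total = 0
    cur_start = cur_end = None
    for left, right in sorted(intervals):
        if cur_end is None or left > cur_end + 1:
            # Gap before this interval: close out the current merged run.
            if cur_end is not None:
                total += cur_end - cur_start + 1
            cur_start, cur_end = left, right
        else:
            # Overlapping or adjacent: extend the current merged run.
            cur_end = max(cur_end, right)
    if cur_end is not None:
        total += cur_end - cur_start + 1
    return total


print(covered_count([(2, 3), (7, 10)]))          # 6
print(covered_count([(2, 3), (7, 10), (5, 8)]))  # 8
```

This helper needs all intervals up front, so it is only a checker here, not a replacement for the online add/count interface.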
Explain the time and space complexity of your solution. Consider edge cases such as overlapping intervals and large input ranges (1 <= left <= right <= 10^9). At most 10^5 calls to add and count will be made.
A naive solution is to store all the intervals in a list and, on each count call, enumerate every integer in every stored interval, collecting the integers in a set so that an integer covered by several overlapping intervals is counted only once. (A closely related variant scans every integer from the minimum start to the maximum end and checks whether it falls within any interval.) Both are simple to implement but highly inefficient, because the work grows with the lengths of the intervals rather than with their number.
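class CountIntervals:
    def __init__(self):
        # Store every interval exactly as it was added.
        self.intervals = []
    def add(self, left: int, right: int) -> None:
        self.intervals.append((left, right))
    def count(self) -> int:
        # Enumerate every integer of every interval; the set removes
        # duplicates coming from overlapping intervals.
        counted = set()
        for left, right in self.intervals:
            for i in range(left, right + 1):
                counted.add(i)
        return len(counted)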
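add: O(1)
count: O(N*K) time and up to O(N*K) extra space for the set, where N is the number of intervals and K is their average length. Since a single interval can span up to 10^9 integers, even one count call can be far too slow and the set far too large. We can do much better.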
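To put the count bound in perspective under the stated constraints, write the i-th stored interval as [l_i, r_i] (notation introduced here for illustration). One naive count call does one set insertion per integer of every stored interval, and N can be as large as the 10^5 call limit:

$$
T_{\text{count}} = \sum_{i=1}^{N} (r_i - l_i + 1) \le N \cdot 10^9 \le 10^5 \cdot 10^9 = 10^{14}
$$

A single stored interval [1, 10^9] already forces 10^9 set insertions on every count call.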
A more efficient solution uses a dynamic (lazily built) segment tree over the whole value range [1, 10^9]. Instead of merging intervals explicitly, the tree records which parts of the range are covered: each node represents a range of integers and stores how many integers in that range are covered by the intervals added so far. Overlapping intervals are handled naturally, because covering an already covered range does not change its count. With this structure, add runs in O(log R) time, where R is the size of the value range, and count runs in O(1) time.
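Written as a recurrence, the value stored in a node v with range [s_v, e_v] and children v_L, v_R is as follows (the first case that applies wins):

$$
\text{count}(v) =
\begin{cases}
e_v - s_v + 1 & \text{if } [s_v, e_v] \text{ lies inside some added interval},\\
\text{count}(v_L) + \text{count}(v_R) & \text{if } v \text{ has children},\\
0 & \text{otherwise.}
\end{cases}
$$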
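Segment tree structure: the root represents the whole value range [1, 10^9], and each node represents a range [start, end] and stores:
- start: start of the range.
- end: end of the range.
- count: number of integers in this range covered by the intervals added so far.
- left, right: child nodes for the sub-ranges [start, mid] and [mid+1, end], where mid = (start + end) // 2. Children are created only when an update needs to descend into the node.
Add operation: when adding an interval [left, right], traverse the tree from the root and visit only the nodes that overlap the new interval. A node that lies entirely inside [left, right] is marked fully covered (count = end - start + 1) and the traversal stops there; a node that overlaps only partially is split into its children, and its count is recomputed as the sum of its children's counts.
Count operation: count simply returns the count stored at the root, which is the number of covered integers in the whole range.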
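class CountIntervals:
    # Value range from the constraints: 1 <= left <= right <= 10^9.
    MIN_VALUE = 1
    MAX_VALUE = 10**9
    class Node:
        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.count = 0      # covered integers within [start, end]
            self.left = None    # child for [start, mid], created on demand
            self.right = None   # child for [mid+1, end], created on demand
    def __init__(self):
        # The root must span the whole value range, not just the first interval.
        self.root = CountIntervals.Node(self.MIN_VALUE, self.MAX_VALUE)
    def add(self, left: int, right: int) -> None:
        if left > right:
            return  # ignore empty/invalid intervals
        def update(node, start, end):
            # Disjoint from [start, end]: nothing changes.
            if node.start > end or node.end < start:
                return
            # Already fully covered: nothing can change, and descending could
            # overwrite this count with stale child counts.
            if node.count == node.end - node.start + 1:
                return
            # Node lies entirely inside [start, end]: mark it fully covered.
            if start <= node.start and node.end <= end:
                node.count = node.end - node.start + 1
                return
            # Partial overlap: create children lazily, recurse, then recombine.
            mid = (node.start + node.end) // 2
            if node.left is None:
                node.left = CountIntervals.Node(node.start, mid)
            if node.right is None:
                node.right = CountIntervals.Node(mid + 1, node.end)
            update(node.left, start, end)
            update(node.right, start, end)
            node.count = node.left.count + node.right.count
        update(self.root, left, right)
    def count(self) -> int:
        return self.root.count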
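add: O(log R) time, where R = 10^9 is the size of the value range. At each level of the tree at most two nodes overlap the new interval only partially, so only O(log R) nodes are visited.
count: O(1), since the answer is stored at the root.
Space: O(Q log R), where Q is the number of add calls, because each add creates at most O(log R) new nodes.
Edge case: handle calls with left > right appropriately (e.g., by ignoring such intervals).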
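For the stated range, the logarithmic factor is small. The tree halves its range at every level, so its depth is:

$$
\lceil \log_2 10^9 \rceil = 30, \quad \text{since } 2^{29} = 536{,}870{,}912 < 10^9 \le 2^{30} = 1{,}073{,}741{,}824.
$$

Each add therefore visits a small constant number of nodes on each of at most 31 levels (depth 0 to 30), and with at most 10^5 add calls the total number of nodes ever created is at most a small constant times 10^5 × 30.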