Given an integer array arr and an integer target, return the number of tuples i, j, k such that i < j < k and arr[i] + arr[j] + arr[k] == target.
As the answer can be very large, return it modulo 10^9 + 7.
Example 1:
Input: arr = [1,1,2,2,3,3,4,4,5,5], target = 8
Output: 20
Explanation:
Enumerating by the values (arr[i], arr[j], arr[k]):
(1, 2, 5) occurs 8 times;
(1, 3, 4) occurs 8 times;
(2, 2, 4) occurs 2 times;
(2, 3, 3) occurs 2 times.
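The breakdown above can be reproduced by brute force. The sketch below enumerates every index triple i < j < k, keeps those whose values sum to target, and groups them by their sorted value triple. The helper name `value_triple_counts` is illustrative, not part of the problem.

```python
from collections import Counter
from itertools import combinations


def value_triple_counts(arr, target):
    """Group index triples i < j < k with the given sum by their value triple."""
    groups = Counter()
    # combinations(range(n), 3) yields each index triple i < j < k exactly once.
    for i, j, k in combinations(range(len(arr)), 3):
        if arr[i] + arr[j] + arr[k] == target:
            # Key by the sorted values so equal value triples land together.
            groups[tuple(sorted((arr[i], arr[j], arr[k])))] += 1
    return groups


counts = value_triple_counts([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], 8)
for values, occurrences in sorted(counts.items()):
    print(values, occurrences)
# (1, 2, 5) 8
# (1, 3, 4) 8
# (2, 2, 4) 2
# (2, 3, 3) 2
print(sum(counts.values()))  # 20
```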
Example 2:
Input: arr = [1,1,2,2,2,2], target = 5
Output: 12
Explanation:
arr[i] = 1, arr[j] = arr[k] = 2 occurs 12 times:
We choose one 1 from [1,1] in 2 ways,
and two 2s from [2,2,2,2] in 6 ways.
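The "2 ways times 6 ways" reasoning is a product of binomial coefficients. The following sketch checks it with `math.comb` (Python 3.8+) and cross-checks the result against a direct enumeration of index triples.

```python
from itertools import combinations
from math import comb

# Example 2 by counting: pick one of the two 1s and two of the four 2s.
by_formula = comb(2, 1) * comb(4, 2)  # 2 * 6

# Cross-check by brute force: combinations() picks positions i < j < k,
# so equal values at different positions are counted separately.
arr = [1, 1, 2, 2, 2, 2]
by_enumeration = sum(1 for triple in combinations(arr, 3) if sum(triple) == 5)

print(by_formula, by_enumeration)  # 12 12
```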
Example 3:
Input: arr = [2,1,3], target = 6
Output: 1
Explanation: (1, 2, 3) occurs one time in the array, so we return 1.
Constraints:
3 <= arr.length <= 3000
0 <= arr[i] <= 100
0 <= target <= 300
Let's walk through a solution to the 3Sum Multiplicity problem.
The most straightforward approach is to use three nested loops to iterate through all possible combinations of i, j, and k such that i < j < k. For each combination, we check whether arr[i] + arr[j] + arr[k] equals target. If it does, we increment our count. Finally, we return the count modulo 10^9 + 7.
```python
def threeSumMultiplicity_naive(arr, target):
    n = len(arr)
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                if arr[i] + arr[j] + arr[k] == target:
                    count += 1
    return count % (10**9 + 7)
```
Time Complexity: O(n^3), due to the three nested loops. Space Complexity: O(1), as we only use a constant amount of extra space.
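The same brute force can be written more compactly with `itertools.combinations`, which yields the elements at positions i < j < k exactly once each. It is still O(n^3), but it can serve as a simple reference when testing faster versions. The function name here is illustrative.

```python
from itertools import combinations

MOD = 10**9 + 7


def three_sum_multiplicity_combinations(arr, target):
    # combinations(arr, 3) picks positions i < j < k, so equal values at
    # different positions count as different tuples, as the problem requires.
    count = sum(1 for x, y, z in combinations(arr, 3) if x + y + z == target)
    return count % MOD


print(three_sum_multiplicity_combinations([2, 1, 3], 6))  # 1
```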
We can optimize this solution with a counting approach. We store the frequency of each number in arr in a hash map (or, since the constraints limit arr[i] to the range 0 to 100, in an array of size 101). Then we iterate through all pairs of distinct values a <= b and check whether the third value c = target - a - b needed to reach the target exists in the frequency map. Requiring a <= b <= c counts each combination of values exactly once; the number of index tuples it contributes then depends on how many of a, b, and c are equal.
```python
def threeSumMultiplicity(arr, target):
    MOD = 10**9 + 7
    count = 0
    # Count how many times each value appears.
    freq = {}
    for num in arr:
        freq[num] = freq.get(num, 0) + 1
    nums = sorted(freq.keys())
    # Enumerate value triples a <= b <= c so each combination is counted once.
    for i in range(len(nums)):
        a = nums[i]
        for j in range(i, len(nums)):
            b = nums[j]
            c = target - a - b
            if c in freq and c >= b:
                if a == b == c:
                    # Choose 3 of the freq[a] copies of a.
                    count = (count + freq[a] * (freq[a] - 1) * (freq[a] - 2) // 6) % MOD
                elif a == b != c:
                    # Choose 2 copies of a and 1 copy of c.
                    count = (count + freq[a] * (freq[a] - 1) // 2 * freq[c]) % MOD
                elif a != b == c:
                    # Choose 1 copy of a and 2 copies of b.
                    count = (count + freq[a] * freq[b] * (freq[b] - 1) // 2) % MOD
                elif a < b < c:
                    # All three values differ: one copy of each.
                    count = (count + freq[a] * freq[b] * freq[c]) % MOD
    return count
```
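The prose above mentions that, because 0 <= arr[i] <= 100, an array can replace the hash map. A minimal sketch of that variant follows. It assumes every input value lies in that range (a negative value would silently index from the end of a Python list), and the names `three_sum_multiplicity_array` and `MAX_VALUE` are illustrative.

```python
MOD = 10**9 + 7
MAX_VALUE = 100  # assumed from the constraint 0 <= arr[i] <= 100


def three_sum_multiplicity_array(arr, target):
    # freq[v] = number of times value v appears in arr.
    freq = [0] * (MAX_VALUE + 1)
    for num in arr:
        freq[num] += 1

    count = 0
    # Enumerate values a <= b <= c; c is determined by a and b.
    for a in range(MAX_VALUE + 1):
        for b in range(a, MAX_VALUE + 1):
            c = target - a - b
            if c < b or c > MAX_VALUE:
                continue  # keeps a <= b <= c and c inside the table
            if a == b == c:
                count += freq[a] * (freq[a] - 1) * (freq[a] - 2) // 6
            elif a == b:
                count += freq[a] * (freq[a] - 1) // 2 * freq[c]
            elif b == c:
                count += freq[a] * freq[b] * (freq[b] - 1) // 2
            else:
                count += freq[a] * freq[b] * freq[c]
    # Python integers are arbitrary precision, so one reduction at the end suffices.
    return count % MOD


print(three_sum_multiplicity_array([1, 1, 2, 2, 2, 2], 5))  # 12
```

Values that never appear have frequency 0, so every product for them is 0 and no extra membership check is needed.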
Time Complexity: O(n + k^2), where n is the length of the array and k is the number of distinct elements in the array. The n term comes from counting the frequencies, and the k^2 term comes from iterating through the pairs of distinct values (sorting the k keys adds O(k log k), which is dominated by k^2). Because 0 <= arr[i] <= 100, k is at most 101, so the k^2 term is bounded by a constant and the runtime is effectively O(n).
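For a sense of scale at the largest allowed input (n = 3000, k = 101), the nested loops examine every index triple, while the counting loop examines each pair of distinct values with a <= b once:

$$\binom{3000}{3} = \frac{3000 \cdot 2999 \cdot 2998}{6} = 4{,}495{,}501{,}000$$

$$\frac{k(k+1)}{2} = \frac{101 \cdot 102}{2} = 5151$$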
Space Complexity: O(k), where k is the number of distinct elements in the array, which is the space used to store the frequency map.
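The count is reduced modulo 10^9 + 7 because the problem requires it. Python integers never overflow, but in languages with fixed-width integers, reducing after each addition also keeps the running total from overflowing.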