1856. Maximum Subarray Min-Product

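Same idea as LeetCode 907 (Sum of Subarray Minimums): treat each element as the minimum of a range found with a monotonic stack.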

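The prefix-sum part differs: this version uses an n-length prefix-sum array (no leading presum[0] = 0). I'm not sure why that style should be preferred; an (n+1)-length array with presum[0] = 0 avoids the special case below.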

And be careful when the index is -1: presum[-1] stands for an empty sum, so treat it as 0.

T: O(n)

S: O(n)

class Solution {
    private static final long MOD = 1_000_000_007L;

    /**
     * Returns the maximum min-product over all non-empty subarrays of {@code nums}, modulo 1e9+7.
     * The min-product of a subarray is min(subarray) * sum(subarray). The maximum is taken before
     * the modulo is applied.
     *
     * <p>For each index i, a single monotonic-stack pass finds:
     * <ul>
     *   <li>the previous index holding a strictly smaller value, and</li>
     *   <li>the next index holding a smaller-or-equal value.</li>
     * </ul>
     * nums[i] is the minimum of every element strictly between those two indices. Every subarray
     * lies inside the range of its rightmost minimum, so the best candidate is always examined.
     * Each index is pushed and popped at most once, giving O(n) time and O(n) extra space.
     *
     * <p>NOTE(review): assumes positive values and that the unmodded maximum fits in a
     * {@code long} (as the LeetCode 1856 statement guarantees) — confirm for other callers.
     *
     * @param nums input values
     * @return the maximum min-product mod 1e9+7, or 0 for an empty array
     */
    public int maxSumMinProduct(int[] nums) {
        int n = nums.length;
        if (n == 0) {
            return 0;
        }

        // presum[k] = nums[0] + ... + nums[k-1]. The leading 0 removes the index == -1 special case.
        long[] presum = new long[n + 1];
        for (int i = 0; i < n; i++) {
            presum[i + 1] = presum[i] + nums[i];
        }

        int[] prevSmaller = new int[n]; // nearest index to the left with a strictly smaller value, or -1
        int[] nextSmaller = new int[n]; // nearest index to the right with a smaller-or-equal value, or n
        Arrays.fill(nextSmaller, n);
        Deque<Integer> stack = new ArrayDeque<>(); // indices whose values strictly increase bottom-to-top
        for (int i = 0; i < n; i++) {
            // nums[i] is the first smaller-or-equal value to the right of every index it pops.
            while (!stack.isEmpty() && nums[stack.peek()] >= nums[i]) {
                nextSmaller[stack.pop()] = i;
            }
            // After popping, whatever remains on top is strictly smaller than nums[i].
            prevSmaller[i] = stack.isEmpty() ? -1 : stack.peek();
            stack.push(i);
        }

        long best = Long.MIN_VALUE;
        for (int i = 0; i < n; i++) {
            // Sum of nums[prevSmaller[i] + 1 .. nextSmaller[i] - 1], the range where nums[i] is the minimum.
            long sum = presum[nextSmaller[i]] - presum[prevSmaller[i] + 1];
            best = Math.max(best, (long) nums[i] * sum);
        }
        return (int) (best % MOD);
    }
}
/*
1. Use a monotonic stack to find nextSmaller and prevSmaller for every index. Together they give
the range of the subarray in which each number is the minimum.

1 as min, => [4, 4, 3, 4]

2 3

3 2 [x,x, 3, x]
2 1
1 0

1
2 
 [-1,0,1,-1]
2


next[i] 4 - 0
prev0-(-1) = 1


sum[0,3] = presum[3+1] - presum[0]
=presum[next[i]]-[0]

2. Then use prefix sums to calculate each subarray sum: sum(i..j) = presum[j+1] - presum[i], where presum[0] = 0.

 0 1 2 3 4 5 6 7
[2,5,4,2,4,5,3,1,2,4
*/

AWS online assessment (OA) variant: sum of min * sum over all subarrays

https://leetcode.com/discuss/interview-question/1736639/Solution-to-Amazon-OA-2022-problem-Sum-of-Scores-of-Subarray

[2,3,2,1]

Debug output of the first solution for the input above:

[3, 2, 3, 4] next

[-1, 0, 0, -1] prev

With 2 (index 0) as the minimum, the range is [2,3,2]. The subarrays of [2,3,2] with minimum 2 have sums 2, 5 (2,3), 7 (2,3,2), 5 (3,2), 2.

class Solution {
    public int maxSumMinProduct(int[] arr) {
    
        int n = arr.length;
        Stack<Integer> stack = new Stack();
        long ans = 0;
        long[] sum = new long[n+1], sum1 = new long[n+1];//prefix sum and suffix sum
        long[] prefix = new long[n+1], suffix = new long[n+1];//prefix sum of prefixsum and suffix sum of suffixsum
        long mod = 1_000_000_007;
        for(int i = 0; i < n; i++)
        {
            sum[i+1] = (sum[i]+arr[i])%mod;
            prefix[i+1] = (prefix[i]+sum[i+1])%mod;
        }
        for(int i = n-1; i >= 0; i--)
        {
            sum1[i] = (sum1[i+1]+arr[i])%mod;
            suffix[i] = (suffix[i+1]+sum1[i])%mod;
        }
        
        for(int i = 0; i < n; i++)
        {
            while(!stack.isEmpty() && arr[stack.peek()] >= arr[i])
            {
                int cur =  stack.pop();
                int prev = stack.isEmpty() ? -1 : stack.peek();
                int next = i;
                //the commented lines shows how the left sum and right sum are calculated
                //Maybe a little hard to understand but the idea is similar
                //long lsum = suffix[prev+1]-suffix[cur]-sum1[cur]*(cur-prev-1)%mod;
                //long rsum = prefix[next]-prefix[cur+1]-sum[cur+1]*(next-cur-1)%mod;
                //below I takes modulo everytime there is multiplication and addition to avoid overflow in java
                long lsum = (mod+suffix[prev+1]-(suffix[cur]+sum1[cur]*(cur-prev-1)%mod)%mod)%mod;
                long rsum = (mod+prefix[next]-(prefix[cur+1]+sum[cur+1]*(next-cur-1)%mod)%mod)%mod;
                long self = ((long)arr[cur]*(next-cur)%mod)*(cur-prev)%mod;
                long curres = (long)arr[cur]*((rsum*(cur-prev)%mod+lsum*(next-cur)%mod+self)%mod)%mod;
                ans =(ans+curres)%mod;
            }
            stack.push(i);
        }
        while(!stack.isEmpty())//do the same thing for the remaining numbers
        {
            int cur = stack.pop();
            int prev = stack.isEmpty() ? -1 : stack.peek();
            int next = n;
            //comments are the same as previously
            long lsum = (mod+suffix[prev+1]-(suffix[cur]+sum1[cur]*(cur-prev-1))%mod)%mod;
            long rsum = (mod+prefix[next]-(prefix[cur+1]+sum[cur+1]*(next-cur-1))%mod)%mod;
            long self = ((long)arr[cur]*(next-cur)%mod)*(cur-prev)%mod;
            long curres = (long)arr[cur]*((rsum*(cur-prev)%mod+lsum*(next-cur)%mod+self)%mod)%mod;
            ans =(ans+curres)%mod;
        }
        return (int)ans;
    }
}

Last updated