```java
/**
 * Splits an array into contiguous subarrays of bounded length, replaces every element of each
 * subarray with that subarray's maximum, and finds the largest total obtainable this way.
 */
class Solution {

    /**
     * Computes the maximum sum of {@code arr} after partitioning it into contiguous subarrays of
     * length at most {@code k} and overwriting each subarray with its own maximum value.
     *
     * <p>Uses a bottom-up table over suffixes, so it needs no recursion and handles negative
     * values correctly. All sums are computed in {@code int} with no overflow checks.
     *
     * @param arr values to partition; may be empty and may contain negative numbers
     * @param k maximum number of elements allowed in one subarray
     * @return the best achievable sum, or 0 when {@code arr} is empty or {@code k < 1}
     * @throws NullPointerException if {@code arr} is null
     */
    public int maxSumAfterPartitioning(int[] arr, int k) {
        int n = arr.length;
        if (n == 0 || k < 1) {
            // Nothing to partition, or no subarray length is permitted.
            return 0;
        }

        // best[s] holds the optimal sum for the suffix arr[s..n-1]; best[n] = 0 (empty suffix).
        int[] best = new int[n + 1];
        for (int start = n - 1; start >= 0; start--) {
            int end = Math.min(start + k, n);

            // Seed from the one-element subarray instead of 0 so that negative inputs are not
            // silently clamped to zero.
            int windowMax = arr[start];
            int bestHere = windowMax + best[start + 1];

            // Grow the first subarray one element at a time, tracking its running maximum.
            for (int i = start + 1; i < end; i++) {
                windowMax = Math.max(windowMax, arr[i]);
                int length = i - start + 1;
                bestHere = Math.max(bestHere, windowMax * length + best[i + 1]);
            }
            best[start] = bestHere;
        }
        return best[0];
    }
}
/**
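Time:  O(n*k) - each of the n start indices tries at most k subarray lengths.
Space: O(n)   - one cached result per start index.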
9, 9, 7, 6
9*4 + 9*4 +
36 + 36 + 21
[1 | 4,1,5,7 | 3,6,1,9 | 9,3]  ->  1 + 7*4 + 9*4 + 9*2 = 83
12 + 28 + 36
8 + 21 + 12 + 36
54
28
1
83
*/
```