```java
/**
 * Computes the least number of perfect squares (1, 4, 9, ...) that sum to a given value.
 */
class Solution {
    /**
     * Returns the least number of perfect squares whose sum is {@code n}.
     *
     * <p>Fills a bottom-up table where {@code best[m]} holds the answer for {@code m}, using the
     * recurrence {@code best[m] = 1 + min(best[m - i*i])} over all {@code i >= 1} with
     * {@code i*i <= m}. The table is built iteratively, so stack depth stays constant regardless
     * of {@code n}. A top-down recursive formulation recurses about {@code n} frames deep and can
     * overflow the thread stack for large inputs.
     *
     * @param n the target sum; must be non-negative
     * @return the minimum count of perfect squares summing to {@code n}; {@code 0} when
     *     {@code n == 0}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public int numSquares(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        int[] best = new int[n + 1];
        // best[0] stays 0: the empty sum needs no squares.
        for (int m = 1; m <= n; m++) {
            int count = Integer.MAX_VALUE;
            // `i <= m / i` is equivalent to `i * i <= m` for positive ints, but cannot overflow
            // when m is close to Integer.MAX_VALUE. i = 1 always qualifies, so count is set.
            for (int i = 1; i <= m / i; i++) {
                count = Math.min(count, best[m - i * i] + 1);
            }
            best[m] = count;
        }
        return best[n];
    }
}
```