```java
class Solution {
    /**
     * Node on the earlier side of the first in-order inversion; null until one is found.
     * Assuming exactly two values were swapped, this holds the larger of them.
     */
    TreeNode first;
    /**
     * Node on the later side of the most recent in-order inversion; null until one is found.
     * Assuming exactly two values were swapped, this holds the smaller of them.
     */
    TreeNode second;
    /** Previously visited node in the in-order walk; {@code null} before the first visit. */
    TreeNode pre;

    /**
     * Repairs a binary search tree in which the values of two nodes were swapped, by
     * swapping those two values back in place (the tree structure is not changed).
     *
     * <p>If the in-order walk finds no inversion (empty tree or already-ordered tree),
     * the tree is left unchanged.
     *
     * @param root root of the tree; may be {@code null}
     */
    public void recoverTree(TreeNode root) {
        // Reset per-call state so one Solution instance can be reused on several trees.
        first = null;
        second = null;
        pre = null;

        dfs(root);

        if (first == null) {
            return; // no inversion found: nothing to repair
        }
        int temp = first.val;
        first.val = second.val;
        second.val = temp;
    }

    /**
     * Recursive in-order walk that records the nodes involved in ordering inversions
     * (a visited node whose value is smaller than its in-order predecessor's).
     *
     * @param node subtree root to walk; {@code null} is a no-op
     */
    private void dfs(TreeNode node) {
        if (node == null) {
            return;
        }
        dfs(node.left);
        if (pre != null && pre.val > node.val) {
            // Do not swap during the walk: later comparisons must see the current values.
            if (first == null) {
                // First inversion: its earlier element is one of the swapped nodes.
                first = pre;
            }
            // Always overwrite: swapping adjacent in-order nodes yields one inversion, while
            // non-adjacent nodes yield two, and the later one's node is the other swapped node.
            second = node;
        }
        pre = node;
        dfs(node.right);
    }
}
```