// Delete Node In BST

// https://leetcode.com/problems/delete-node-in-a-bst

/* 
Given a root node reference of a BST and a key, delete the node with the given key in the BST. 
Return the root node reference (possibly updated) of the BST.
Basically, the deletion can be divided into two stages:

1. Search for a node to remove.
2. If the node is found, delete the node.

Ex1:
Input: root = [5,3,6,2,4,null,7], key = 3
Output: [5,4,6,2,null,null,7]
Explanation: The key to delete is 3, so we find the node with value 3 and delete it.
One valid answer is [5,4,6,2,null,null,7] (trees are written in level order, with null for a missing child).
Another valid answer is [5,2,6,null,4,null,7], which is also accepted.

Ex2:
Input: root = [5,3,6,2,4,null,7], key = 0
Output: [5,3,6,2,4,null,7]
Explanation: The tree does not contain a node with value = 0.

Ex3:
Input: root = [], key = 0
Output: []

*/

#include <cassert>
#include <iostream>

using namespace std;

// Binary tree node holding an int value and links to two children.
// Both child links start out null; callers attach subtrees by assigning them.
class TreeNode {
public:
    int val;
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;

    TreeNode(int value) : val(value) {}
};

// Returns the value of root's in-order successor: the leftmost (smallest)
// node of root's right subtree.
// Precondition: root and root->right are both non-null (not checked here).
int successor(TreeNode* root) {
    TreeNode* cursor = root->right;
    while (cursor->left != nullptr) {
        cursor = cursor->left;
    }
    return cursor->val;
}
// Returns the value of root's in-order predecessor: the rightmost (largest)
// node of root's left subtree.
// Precondition: root and root->left are both non-null (not checked here).
int predecessor(TreeNode* root) {
    TreeNode* cursor = root->left;
    while (cursor->right != nullptr) {
        cursor = cursor->right;
    }
    return cursor->val;
}

// Deletes the node holding `key` from the BST rooted at `root` and returns the
// (possibly new) root. If `key` is not present, the tree is returned unchanged.
//
// Strategy: when the matching node has a right subtree, its value is replaced
// by its in-order successor's value and that successor is then deleted from
// the right subtree. Otherwise, the in-order predecessor from the left subtree
// is used. Only a leaf is ever physically unlinked, and it is released with
// `delete`, so all nodes must have been allocated with `new`.
// NOTE(review): assumes keys in the tree are distinct, as the problem states.
//
// Time complexity: O(H), Space complexity: O(H) (recursion depth)
// H is the height of the tree, which is O(logN) for balanced tree, O(N) for unbalanced tree
TreeNode* deleteNode(TreeNode* root, int key) {
    if (!root) {
        return nullptr;
    }

    if (root->val == key) {
        if (!root->left && !root->right) {
            // Leaf: it is about to become unreachable, so free it here; the
            // caller overwrites its child link with the returned nullptr.
            delete root;
            return nullptr;
        } else if (root->right) {
            // Pull the successor's value up, then remove the successor below.
            root->val = successor(root);
            root->right = deleteNode(root->right, root->val);
        } else {
            // No right subtree: pull the predecessor's value up instead.
            root->val = predecessor(root);
            root->left = deleteNode(root->left, root->val);
        }
    } else if (root->val > key) {
        root->left = deleteNode(root->left, key);
    } else {
        root->right = deleteNode(root->right, key);
    }

    return root;
}

// Writes the tree's values to stdout in in-order (left subtree, node, right
// subtree), each followed by a single space. No newline is emitted.
void printTree(TreeNode* root) {
    if (root != nullptr) {
        printTree(root->left);
        cout << root->val << " ";
        printTree(root->right);
    }
}

int main() {
    TreeNode* root1 = new TreeNode(5);
    root1->left = new TreeNode(3);
    root1->right = new TreeNode(6);
    root1->left->left = new TreeNode(2);
    root1->left->right = new TreeNode(4);
    root1->right->right = new TreeNode(7);
    TreeNode* res1 = deleteNode(root1, 3);
    printTree(res1);
    assert(res1->val == 5);
    assert(res1->left->val == 4);
    assert(res1->right->val == 6);
    assert(res1->left->left->val == 2);
    assert(res1->left->right == nullptr);
    assert(res1->right->right->val == 7);

    TreeNode* root2 = new TreeNode(5);
    root2->left = new TreeNode(3);
    root2->right = new TreeNode(6);
    root2->left->left = new TreeNode(2);
    root2->left->right = new TreeNode(4);
    root2->right->right = new TreeNode(7);
    TreeNode* res2 = deleteNode(root2, 0);
    printTree(res2);
    assert(res2->val == 5);

    TreeNode* root3 = nullptr;
    TreeNode* res3 = deleteNode(root3, 0);
    assert(res3 == nullptr);

    return 0;
}```

// Last updated