Balancing an AVL tree

Background - I got through college without having to learn this. Back then, I thought AVL trees were an abomination.

OK. AVL trees.

What - AVL trees are self-balancing binary search trees.
Binary search tree? - A binary tree in which, for any given node, every node in its left subtree is smaller than it and every node in its right subtree is larger. This makes it easy to search.

The problem - What if a BST is skewed to one side? Let's say elements are inserted in the following order: 2, 3, 4, 5, 6, 7. The tree one would get is:

  2
   \
    3
    \
     4
      \
       5
        \
         6
          \
           7

In such a tree, it would take 6 steps to find 7, which is O(n) in the worst case. Suppose instead we had a tree that looked like this:

    5
   / \
  3   6
 / \   \
2   4   7

It would take only 3 comparisons to find 7.

In a perfectly balanced binary search tree, it would take O(log n) steps to find, say, the element furthest from the root. This is because, for every node, the weight (number of nodes) of the left and right subtrees is the same, so the number of elements left to search halves at every node. If you know binary search on a sorted array, this is exactly the same logic.
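
To put numbers on the two trees above, here is a minimal C++ sketch that builds each of them with a plain (non-balancing) BST insert and counts the nodes visited while searching. The names BstNode, bst_insert, build and comparisons are made up for this illustration, and nodes are never freed to keep it short.

```cpp
#include <cassert>
#include <initializer_list>

// Hypothetical minimal node type for this sketch.
struct BstNode {
    int val;
    BstNode *left = nullptr;
    BstNode *right = nullptr;
    explicit BstNode(int v) : val(v) {}
};

// Plain BST insert with no rebalancing. Nodes are leaked for brevity.
BstNode *bst_insert(BstNode *root, int val) {
    if (!root) return new BstNode(val);
    if (val > root->val) root->right = bst_insert(root->right, val);
    else                 root->left  = bst_insert(root->left, val);
    return root;
}

// Builds a tree by inserting the values in the given order.
BstNode *build(std::initializer_list<int> vals) {
    BstNode *root = nullptr;
    for (int v : vals) root = bst_insert(root, v);
    return root;
}

// Number of nodes visited while searching for val (whether found or not).
int comparisons(const BstNode *root, int val) {
    int steps = 0;
    while (root) {
        ++steps;
        if (val == root->val) break;
        root = (val > root->val) ? root->right : root->left;
    }
    return steps;
}

// Checks the counts quoted in the text for the two example trees.
void check_search_costs() {
    assert(comparisons(build({2, 3, 4, 5, 6, 7}), 7) == 6);  // skewed tree
    assert(comparisons(build({5, 3, 6, 2, 4, 7}), 7) == 3);  // balanced tree
}
```

Same six values in both trees; only the insertion order, and therefore the shape, differs.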

The solution - At every insert into the BST, make sure the tree stays roughly balanced. More precisely, at every node, the heights of the left and right subtrees can differ by at most 1. If an insert breaks this invariant, we need one or more 'rotations' to rebalance the tree.

The balance factor of any node n is calculated as:
  balance_factor = height(n->left) - height(n->right)

For an AVL tree, the balance factor of every node belongs to the set {-1, 0, 1}.

We define height as:
  height(n) = if (!n) then -1
              else max(height(n->left), height(n->right)) + 1

A single insert can make one of the subtrees taller by 1. So right after an insert, the balance factor at any node will be one of {-2, -1, 0, 1, 2}. If it is -2 or 2 at some node, we say the tree requires rebalancing there.
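
The height and balance factor definitions above translate almost directly into C++. Here is a small sketch, assuming a bare-bones Node struct invented for this example, plus a hypothetical helper is_height_balanced that checks the {-1, 0, 1} rule at every node. It only checks balance, not the BST ordering.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>

// Hypothetical bare-bones node for this sketch.
struct Node {
    int val;
    Node *left;
    Node *right;
};

// Height as defined above: -1 for an empty tree, 0 for a single node.
int height(const Node *n) {
    if (!n) return -1;
    return std::max(height(n->left), height(n->right)) + 1;
}

// Positive means left heavy, negative means right heavy.
int balance_factor(const Node *n) {
    assert(n != nullptr);
    return height(n->left) - height(n->right);
}

// True if every node's balance factor lies in {-1, 0, 1}.
bool is_height_balanced(const Node *n) {
    if (!n) return true;
    return std::abs(balance_factor(n)) <= 1
        && is_height_balanced(n->left)
        && is_height_balanced(n->right);
}
```

For a three-node chain that only ever goes right, balance_factor at the top node comes out as -2, so is_height_balanced returns false for it.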

Say the tree we have is:

  2
   \
    5

We insert 6

  2
   \
    5
     \
      6

Precisely:

  insert(root, val) =
    if(!root) return new_node(val)

    if(val > root->val)
      if(root->right) insert(root->right, val)
      else root->right = new_node(val)
    else
      if(root->left) insert(root->left, val)
      else root->left = new_node(val)
    return root

But this leaves the tree imbalanced at (2). The balance factor at (2) is calculated as:

  height(2->left) = -1
  height(2->right) = 1
  balance_factor = -2

We balance the tree by rotating it from right to left:

  5
 / \
2   6

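We call this a right right rotation, because the extra weight is in the right subtree of the right child. (Many textbooks call the same operation a left rotation.)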

As you can see, the right child of the current root is rotated to become the new root, and the old root becomes the left child of the new root.

So to rotate a tree at root:
  oldroot = root
  root = root->right
  root->left = oldroot

However, if the tree was instead:

  2
   \
    5
   / \
  4   6

We'd rotate this to:

   5
  / \
 2  6
  \
   4

So we need one extra step: the left child (4) of the to-be root (5) has to become the right child of the old root (2). We'd modify the above rotation to:

right_right_rotation(root) =

  oldroot = root
  root = root->right
  temp = root->left
  root->left = oldroot
  oldroot->right = temp

  root->left->height = height(root->left)
  root->height = height(root)

Note that we recalculate the height of the new root's left child (the old root) first, and then the height of the new root itself.
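
Here is the rotation above as a small, compilable C++ sketch. It assumes a stripped-down Node struct without a height field (so the two height updates are left out), and it returns the new root instead of reassigning root in place. Both choices are simplifications made for this example.

```cpp
#include <cassert>

// Hypothetical stripped-down node: no cached height in this sketch.
struct Node {
    int val;
    Node *left;
    Node *right;
};

// Rotates the subtree at root from right to left and returns the new root.
Node *right_right_rotation(Node *root) {
    assert(root && root->right);    // there must be a right child to promote
    Node *new_root = root->right;
    root->right = new_root->left;   // the extra step: hand over the left subtree
    new_root->left = root;          // old root becomes the left child
    return new_root;
}
```

Applied to the tree 2 -> 5 -> (4, 6) from above, the returned root is 5, with 2 on its left, 6 on its right, and 4 as the right child of 2.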

Let's look at a slightly different scenario, where the tree is still right heavy but the extra weight sits in the left subtree of the right child, unlike the right->right case we saw earlier.

  2
   \
    6
   /
  4 
 /
3


To balance this, we first rotate the subtree rooted at 6 from left to right to get:

  2
   \
    4
   / \
  3   6

Let's think of the subtree rooted at 6 as root. Then

  oldroot = root
  root = root->left
  root->right = oldroot

The result has exactly the shape we balanced earlier with a right right rotation. If node (4) also had a right child (5) and node (6) also had a right child (7), this first rotation would go:

from:

  2
   \
    6
   / \
  4   7
 / \
3   5

to:

  2
   \
    4
   / \
  3   6
     / \
    5   7

Therefore to balance such cases, we do:
  1. right_left_rotation
  2. then right_right_rotation

where right_left_rotation(root) =
  old_right = root->right
  root->right = root->right->left
  temp = root->right->right
  root->right->right = old_right
  root->right->right->left = temp

  root->right->right->height = height(root->right->right)
  root->right->height = height(root->right)

Note that we have to recalculate the heights at each rotation like earlier.
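
The two-step fix can be sketched in C++ as two plain rotations chained together. This is an illustration under a few assumptions: the Node struct has no height field, the helper names rotate_left, rotate_right and fix_right_left are invented here, and rotate_left plays the role of right_right_rotation from the text.

```cpp
#include <cassert>

// Hypothetical stripped-down node: no cached height in this sketch.
struct Node {
    int val;
    Node *left;
    Node *right;
};

// Promotes the right child (the "right right rotation" of the text).
Node *rotate_left(Node *root) {
    assert(root && root->right);
    Node *new_root = root->right;
    root->right = new_root->left;
    new_root->left = root;
    return new_root;
}

// Promotes the left child (rotating "from left to right").
Node *rotate_right(Node *root) {
    assert(root && root->left);
    Node *new_root = root->left;
    root->left = new_root->right;
    new_root->right = root;
    return new_root;
}

// Right-left case: straighten the right subtree, then rotate at root.
Node *fix_right_left(Node *root) {
    root->right = rotate_right(root->right);  // step 1
    return rotate_left(root);                 // step 2
}
```

Starting from 2 -> 6 -> (4 -> (3, 5), 7), step 1 gives the "to" tree above, and step 2 makes 4 the overall root with 2 on its left (3 becomes the right child of 2) and 6 -> (5, 7) on its right.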

So far, we have talked about a right-heavy tree. For a left-heavy tree, we have mirror-image scenarios and rotations. We call them left_right_rotation and left_left_rotation.


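left_left_rotation(root) =
  oldroot = root
  root = root->left
  temp = root->right
  root->right = oldroot
  oldroot->left = temp

  root->right->height = height(root->right)
  root->height = height(root)

left_right_rotation(root) =
  old_left = root->left
  root->left = root->left->right
  temp = root->left->left
  root->left->left = old_left
  root->left->left->right = temp

  root->left->left->height = height(root->left->left)
  root->left->height = height(root->left)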
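
To make the mirror image concrete, here is a small C++ sketch of the two left-heavy fixes on three-node trees. The Node struct (no height field) and the names rotate_left, rotate_right, fix_left_left and fix_left_right are assumptions made for this example only.

```cpp
#include <cassert>

// Hypothetical stripped-down node: no cached height in this sketch.
struct Node {
    int val;
    Node *left;
    Node *right;
};

// Promotes the left child; the mirror of the right right rotation.
Node *rotate_right(Node *root) {
    assert(root && root->left);
    Node *new_root = root->left;
    root->left = new_root->right;
    new_root->right = root;
    return new_root;
}

// Promotes the right child.
Node *rotate_left(Node *root) {
    assert(root && root->right);
    Node *new_root = root->right;
    root->right = new_root->left;
    new_root->left = root;
    return new_root;
}

// Left-left case: the weight is in the left child's left subtree.
Node *fix_left_left(Node *root) {
    return rotate_right(root);
}

// Left-right case: straighten the left subtree first, then rotate at root.
Node *fix_left_right(Node *root) {
    root->left = rotate_left(root->left);
    return rotate_right(root);
}
```

A left-leaning chain 6 -> 5 -> 2 becomes 5 -> (2, 6) with fix_left_left, and the zig-zag 6 -> 2 -> 4 becomes 4 -> (2, 6) with fix_left_right.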

So far we have only looked at a shallow tree where the violation was at a single level. What if the tree is pretty deep and there are violations at multiple levels?

Answer - It's easy. Recursion.

We add an element to a binary search tree recursively. As the recursion unwinds, at every level we passed through on the way down, we check the balance factor and, if the subtree is imbalanced, rebalance it, all the way back up to the root.



Code (C++)



/* Node is defined as:
typedef struct node
{
    int val;
    struct node* left;
    struct node* right;
    int ht;
} node; */



node *right_left_rotation(node *root);
node *right_right_rotation(node *root);
node *left_right_rotation(node *root);
node *left_left_rotation(node *root);
int height(node *root);
int balance_factor(node *n);



node *new_node(int val)
{
    node *n = new node;
    n->val = val;
    n->ht = 0;
    n->left = nullptr;
    n->right = nullptr;
    return n;
}



node *insert(node *root, int val)
{
    // Empty subtree: the new node becomes its root.
    if(!root)
    {
        root = new_node(val);
        return root;
    }

    if(val > root->val)
    {
        if(root->right)
        {
            root->right = insert(root->right, val);
            root->ht = height(root);
            // Right heavy after the insert: rebalance on the way back up.
            if(balance_factor(root) < -1)
            {
                if(balance_factor(root->right) == 1)
                {
                    // The weight is in the right child's left subtree.
                    root = right_left_rotation(root);
                    root = right_right_rotation(root);
                }
                else
                {
                    root = right_right_rotation(root);
                }
            }
        }
        else
        {
            root->right = new_node(val);
            root->ht = height(root);
        }
    }
    else
    {
        if(root->left)
        {
            root->left = insert(root->left, val);
            root->ht = height(root);
            // Left heavy after the insert: mirror image of the case above.
            if(balance_factor(root) > 1)
            {
                if(balance_factor(root->left) == -1)
                {
                    // The weight is in the left child's right subtree.
                    root = left_right_rotation(root);
                    root = left_left_rotation(root);
                }
                else
                {
                    root = left_left_rotation(root);
                }
            }
        }
        else
        {
            root->left = new_node(val);
            root->ht = height(root);
        }
    }

    return root;
}



node *right_left_rotation(node *root)
{
    // Rotate the right subtree from left to right, in place under root.
    node *old_right = root->right;
    root->right = root->right->left;
    node *temp = root->right->right;
    root->right->right = old_right;
    root->right->right->left = temp;

    // Recalculate heights bottom-up: the demoted node first, then its new parent.
    root->right->right->ht = height(root->right->right);
    root->right->ht = height(root->right);

    return root;
}



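int height(node *root)
{
    // An empty tree has height -1, so a single node has height 0.
    if(!root) return -1;
    int left_height = height(root->left);
    int right_height = height(root->right);
    return (left_height > right_height ? left_height : right_height) + 1;
}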



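node *right_right_rotation(node *root)
{
    // The right child becomes the new root; its left subtree moves under the old root.
    node *old_root = root;
    root = root->right;
    old_root->right = root->left;
    root->left = old_root;

    // Recalculate the old root's height first, then the new root's.
    root->left->ht = height(root->left);
    root->ht = height(root);
    return root;
}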


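node *left_right_rotation(node *root)
{
    // Rotate the left subtree from right to left, in place under root.
    node *old_left = root->left;
    root->left = root->left->right;
    node *temp = root->left->left;
    root->left->left = old_left;
    root->left->left->right = temp;

    // Recalculate heights bottom-up: the demoted node first, then its new parent.
    root->left->left->ht = height(root->left->left);
    root->left->ht = height(root->left);
    return root;
}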


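node *left_left_rotation(node *root)
{
    // The left child becomes the new root; its right subtree moves under the old root.
    node *old_root = root;
    root = root->left;
    old_root->left = root->right;
    root->right = old_root;

    // Recalculate the old root's height first, then the new root's.
    root->right->ht = height(root->right);
    root->ht = height(root);
    return root;
}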

int balance_factor(node *n)
{
    // Positive means left heavy, negative means right heavy.
    return height(n->left) - height(n->right);
}
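
One thing worth noting about the code above: height() walks the whole subtree on every call, even though each node carries an ht field. As a rough sketch of an alternative, and not the code this post was written around, here is a compact variant that reads the cached height instead, so each height lookup is O(1). The names AvlNode, ht_of, update, balance, rotate_left, rotate_right and avl_insert are all invented for this example, and nodes are never freed.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical node type with a cached subtree height.
struct AvlNode {
    int val;
    int ht = 0;
    AvlNode *left = nullptr;
    AvlNode *right = nullptr;
    explicit AvlNode(int v) : val(v) {}
};

int ht_of(const AvlNode *n) { return n ? n->ht : -1; }  // O(1), -1 for empty
void update(AvlNode *n) { n->ht = std::max(ht_of(n->left), ht_of(n->right)) + 1; }
int balance(const AvlNode *n) { return ht_of(n->left) - ht_of(n->right); }

// Promotes the right child (the post's right right rotation).
AvlNode *rotate_left(AvlNode *root) {
    assert(root && root->right);
    AvlNode *r = root->right;
    root->right = r->left;
    r->left = root;
    update(root);  // old root first, it now sits below r
    update(r);
    return r;
}

// Promotes the left child (the post's left left rotation).
AvlNode *rotate_right(AvlNode *root) {
    assert(root && root->left);
    AvlNode *l = root->left;
    root->left = l->right;
    l->right = root;
    update(root);
    update(l);
    return l;
}

// Recursive insert that rebalances each node on the way back up.
AvlNode *avl_insert(AvlNode *root, int val) {
    if (!root) return new AvlNode(val);
    if (val > root->val) root->right = avl_insert(root->right, val);
    else                 root->left  = avl_insert(root->left, val);
    update(root);
    if (balance(root) < -1) {                        // right heavy
        if (balance(root->right) > 0)                // right-left case
            root->right = rotate_right(root->right);
        return rotate_left(root);
    }
    if (balance(root) > 1) {                         // left heavy
        if (balance(root->left) < 0)                 // left-right case
            root->left = rotate_left(root->left);
        return rotate_right(root);
    }
    return root;
}
```

Inserting 2, 3, 4, 5, 6, 7 in that order with avl_insert ends with 5 at the root, 3 -> (2, 4) on the left and 6 -> 7 on the right, which is the balanced tree drawn near the top of this post.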
