# Balanced BS Tree: AVL Tree¶

## Agenda¶

1. Motives
2. "Balanced" binary trees
3. Essential mechanic: rotation
4. Out-of-balance scenarios & rotation recipes
5. Generalized AVL rebalancing (insertion)
6. Rebalancing on removal

## 1. Motives¶

In [1]:
class BSTree:
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right

def __init__(self):
self.size = 0
self.root = None

assert(val not in self)
if not node:
return BSTree.Node(val)
elif val < node.val:
return node
else:
return node
self.size += 1

def __contains__(self, val):
def contains_rec(node):
if not node:
return False
elif val < node.val:
return contains_rec(node.left)
elif val > node.val:
return contains_rec(node.right)
else:
return True
return contains_rec(self.root)

def __len__(self):
return self.size

def __delitem__(self, val):
assert(val in self)
def delitem_rec(node):
if val < node.val:
node.left = delitem_rec(node.left)
return node
elif val > node.val:
node.right = delitem_rec(node.right)
return node
else:
if not node.left and not node.right:
return None
elif node.left and not node.right:
return node.left
elif node.right and not node.left:
return node.right
else:
# remove the largest value from the left subtree as a replacement
# for the root value of this tree
t = node.left
if not t.right:
node.val = t.val
node.left = t.left
else:
n = t
while n.right.right:
n = n.right
t = n.right
n.right = t.left
node.val = t.val
return node

self.root = delitem_rec(self.root)
self.size -= 1

def pprint(self, width=64):
"""Attempts to pretty-print this tree's contents."""
height = self.height()
nodes  = [(self.root, 0)]
prev_level = 0
repr_str = ''
while nodes:
n,level = nodes.pop(0)
if prev_level != level:
prev_level = level
repr_str += '\n'
if not n:
if level < height-1:
nodes.extend([(None, level+1), (None, level+1)])
repr_str += '{val:^{width}}'.format(val='-', width=width//2**level)
elif n:
if n.left or level < height-1:
nodes.append((n.left, level+1))
if n.right or level < height-1:
nodes.append((n.right, level+1))
repr_str += '{val:^{width}}'.format(val=n.val, width=width//2**level)
print(repr_str)

def height(self):
"""Returns the height of the longest branch of the tree."""
def height_rec(t):
if not t:
return 0
else:
return max(1+height_rec(t.left), 1+height_rec(t.right))
return height_rec(self.root)

In [2]:
t = BSTree()
for x in range(6):
t.pprint()

                               0
-                               1
-               -               -               2
-       -       -       -       -       -       -       3
-   -   -   -   -   -   -   -   -   -   -   -   -   -   -   4
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 5

In [3]:
import sys
sys.setrecursionlimit(100)

t = BSTree()
for x in range(100):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-40551a242d9c> in <module>()
4 t = BSTree()
5 for x in range(100):

22                 return node
24         self.size += 1
25

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

19                 return node
20             else:
22                 return node

15             if not node:
---> 16                 return BSTree.Node(val)
17             elif val < node.val:

RuntimeError: maximum recursion depth exceeded

## 3. Essential mechanic: rotation¶

In [4]:
class AVLTree(BSTree):
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right

def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

assert(val not in self)
if not node:
return AVLTree.Node(val)
elif val < node.val:
return node
else:
return node
self.size += 1

In [5]:
t = AVLTree()
for x in range(6, 0, -1):
t.pprint()

                               6
5                               -
4               -               -               -
3       -       -       -       -       -       -       -
2   -   -   -   -   -   -   -   -   -   -   -   -   -   -   -
1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

In [6]:
t.root.rotate_right()
t.pprint()

                               5
4                               6
3               -               -               -
2       -       -       -       -       -       -       -
1   -   -   -   -   -   -   -   -   -   -   -   -   -   -   -

In [7]:
t.root.rotate_right()
t.pprint()

                               4
3                               5
2               -               -               6
1       -       -       -       -       -       -       -

In [8]:
t.root.left.rotate_right()
t.pprint()

                               4
2                               5
1               3               -               6

In [9]:
class AVLTree(BSTree):
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right

def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

@staticmethod
def height(n):
if not n:
return 0
else:
return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

assert(val not in self)
if not node:
return AVLTree.Node(val)
elif val < node.val:
else:
if AVLTree.Node.height(node.left) > AVLTree.Node.height(node.right)+1:
node.rotate_right()
return node
self.size += 1

In [10]:
val = 50
t = AVLTree()

In [21]:
# (evaluate multiple times with ctrl-enter)
val -= 1
t.pprint()

                               47
43                              49
41              45              48              50
40      42      44      46      -       -       -       -

In [22]:
val = 0
t = AVLTree()

In [29]:
# (evaluate multiple times with ctrl-enter)
val += 1
t.pprint()

                               0
-                               1
-               -               -               2
-       -       -       -       -       -       -       3
-   -   -   -   -   -   -   -   -   -   -   -   -   -   -   4
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 5
---------------------------------------------------------------6


## 4. "Out-of-balance" scenarios & rotation recipes¶

In [30]:
# "left-left" scenario
t = BSTree()
for x in [3, 2, 1]:
t.pprint()

                               3
2                               -
1               -               -               -

In [31]:
# "left-right" scenario
t = BSTree()
for x in [3, 1, 2]:
t.pprint()

                               3
1                               -
-               2               -               -

In [32]:
# "right-right" scenario
t = BSTree()
for x in [1, 2, 3]:
t.pprint()

                               1
-                               2
-               -               -               3

In [33]:
# "right-left" scenario
t = BSTree()
for x in [1, 3, 2]:
t.pprint()

                               1
-                               3
-               -               2               -


## 5. Generalized AVL rebalancing (insertion)¶

In [34]:
class AVLTree(BSTree):
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right

def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

def rotate_left(self):
pass

@staticmethod
def height(n):
if not n:
return 0
else:
return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

@staticmethod
def rebalance(t):
if AVLTree.Node.height(t.left) > AVLTree.Node.height(t.right):
if AVLTree.Node.height(t.left.left) >= AVLTree.Node.height(t.left.right):
# left-left
print('left-left imbalance detected')
t.rotate_right()
else:
# left-right
print('left-right imbalance detected')
t.left.rotate_left()
t.rotate_right()
else:
pass

assert(val not in self)
if not node:
return AVLTree.Node(val)
elif val < node.val:
else:
if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
AVLTree.rebalance(node)
return node
self.size += 1

In [35]:
t = AVLTree()
for x in [10, 5, 1]:
t.pprint()

left-left imbalance detected
5
1                               10

In [36]:
# broken!
t = AVLTree()
for x in [10, 5, 1, 2, 3]:
t.pprint()

left-left imbalance detected
left-right imbalance detected
1
-                               5
-               -               2               10
-       -       -       -       -       3       -       -


## 5. Rebalancing on removal¶

In [37]:
class AVLTree(BSTree):
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right

def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

def rotate_left(self):
pass

@staticmethod
def height(n):
if not n:
return 0
else:
return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

@staticmethod
def rebalance(t):
if AVLTree.Node.height(t.left) > AVLTree.Node.height(t.right):
if AVLTree.Node.height(t.left.left) >= AVLTree.Node.height(t.left.right):
# left-left
print('left-left imbalance detected')
t.rotate_right()
else:
# left-right
print('left-right imbalance detected')
t.left.rotate_left()
t.rotate_right()
else:
pass

assert(val not in self)
if not node:
return AVLTree.Node(val)
elif val < node.val:
else:
if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
AVLTree.rebalance(node)
return node
self.size += 1

def __delitem__(self, val):
assert(val in self)
def delitem_rec(node):
if val < node.val:
node.left = delitem_rec(node.left)
elif val > node.val:
node.right = delitem_rec(node.right)
else:
if not node.left and not node.right:
return None
elif node.left and not node.right:
return node.left
elif node.right and not node.left:
return node.right
else:
# remove the largest value from the left subtree (t) as a replacement
# for the root value of this tree
t = node.left
if not t.right:
node.val = t.val
node.left = t.left
else:
n = t
while n.right.right:
n = n.right
t = n.right
n.right = t.left
node.val = t.val
if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
AVLTree.rebalance(node)
return node

self.root = delitem_rec(self.root)
self.size -= 1

In [41]:
t = AVLTree()
for x in [10, 5, 15, 2]:
t.pprint()

                               10
5                               15
2               -               -               -

In [42]:
del t[15]
t.pprint()

left-left imbalance detected
5
2                               10

In [50]:
t = AVLTree()
for x in range(31, 0, -1):
t.pprint()

left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
left-left imbalance detected
16
8                               24
4               12              20              28
2       6       10      14      18      22      26      30
1   3   5   7   9   11  13  15  17  19  21  23  25  27  29  31

In [51]:
del t[15]
del t[14]
t.pprint()

                               16
8                               24
4               12              20              28
2       6       10      13      18      22      26      30
1   3   5   7   9   11  -   -   17  19  21  23  25  27  29  31

In [52]:
# broken!
del t[16]
t.pprint()

                               13
8                               24
4               12              20              28
2       6       10      -       18      22      26      30
1   3   5   7   9   11  -   -   17  19  21  23  25  27  29  31