# by default, only the result of the last expression in a cell is displayed after evaluation.
# the following forces display of *all* self-standing expressions in a cell.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
class BSTree:
    """Unbalanced binary search tree of unique, mutually comparable values.

    Supports ``in``, ``len()`` and ``del tree[val]``. Nothing rebalances the
    tree, so its shape depends entirely on the order of insertion.
    """

    class Node:
        """One tree cell: a value plus optional left and right subtrees."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

    def __init__(self):
        self.root = None
        self.size = 0

    def add(self, val):
        """Insert ``val`` as a new leaf; asserts that it is not stored yet."""
        assert val not in self

        def insert(subtree):
            # An empty slot is where the new leaf belongs.
            if not subtree:
                return BSTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            return subtree

        self.root = insert(self.root)
        self.size += 1

    def __contains__(self, val):
        """Return True when ``val`` is stored in the tree."""
        def search(subtree):
            if not subtree:
                return False
            if val < subtree.val:
                return search(subtree.left)
            if val > subtree.val:
                return search(subtree.right)
            return True

        return search(self.root)

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    def __delitem__(self, val):
        """Remove ``val``; asserts that it is currently stored."""
        assert val in self

        def remove(subtree):
            if val < subtree.val:
                subtree.left = remove(subtree.left)
                return subtree
            if val > subtree.val:
                subtree.right = remove(subtree.right)
                return subtree

            # ``subtree`` holds the value. With at most one child, that child
            # (or None) simply takes this node's place.
            if not (subtree.left and subtree.right):
                return subtree.left or subtree.right

            # Two children: overwrite the value with its in-order predecessor
            # (the right-most node of the left subtree) and unlink that node.
            parent, pred = subtree, subtree.left
            while pred.right:
                parent, pred = pred, pred.right
            if parent is subtree:
                subtree.left = pred.left
            else:
                parent.right = pred.left
            subtree.val = pred.val
            return subtree

        self.root = remove(self.root)
        self.size -= 1

    def pprint(self, width=64):
        """Print the tree one level per line, with ``-`` marking empty slots.

        Each level splits ``width`` columns evenly among its slots, so the
        layout is only readable for shallow trees.
        """
        depth = self.height()
        pending = [(self.root, 0)]
        pieces = []
        current_level = 0
        while pending:
            cell, level = pending.pop(0)
            if level != current_level:
                current_level = level
                pieces.append('\n')
            slot = width // 2**level
            below_exists = level < depth - 1
            if cell:
                # Queue a child slot (possibly empty) whenever the level
                # below exists, so every row stays aligned.
                if cell.left or below_exists:
                    pending.append((cell.left, level + 1))
                if cell.right or below_exists:
                    pending.append((cell.right, level + 1))
                pieces.append('{val:^{width}}'.format(val=cell.val, width=slot))
            else:
                if below_exists:
                    pending.extend([(None, level + 1)] * 2)
                pieces.append('{val:^{width}}'.format(val='-', width=slot))
        print(''.join(pieces))

    def height(self):
        """Return the number of nodes on the longest root-to-leaf path."""
        def depth_of(subtree):
            if not subtree:
                return 0
            return 1 + max(depth_of(subtree.left), depth_of(subtree.right))

        return depth_of(self.root)
# Demo: insert 0..5 in ascending order into an unbalanced BSTree and print it.
t = BSTree()
for x in range(6):
t.add(x)
t.pprint()
# Recorded cell output (it appears flattened onto one line): each value hangs
# off the right child of the previous one, so the tree is a 6-level chain.
0 - 1 - - - 2 - - - - - - - 3 - - - - - - - - - - - - - - - 4 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 5
import sys
# Cap the interpreter's recursion depth at 100, then insert 100 ascending
# values. NOTE(review): BSTree.add and __contains__ recurse once per tree
# level, so this presumably exists to provoke a RecursionError on the
# degenerate chain -- confirm intent.
sys.setrecursionlimit(100)
t = BSTree()
for x in range(100):
t.add(x)
class AVLTree(BSTree):
    """BSTree variant whose nodes can be rotated right by hand.

    ``add`` inserts exactly like a plain binary search tree; no rotation is
    ever triggered automatically.
    """

    class Node:
        """Tree cell with an in-place right rotation."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

        def rotate_right(self):
            """Rotate the subtree rooted here to the right (left-left fix).

            The rotation is done in place: this node object stays the subtree
            root, so no parent link needs updating. It swaps values with its
            left child and re-hangs the child as its own right child.
            """
            pivot = self.left
            displaced = self.right
            # After the swap this node carries the smaller value and the
            # pivot carries the value that used to be the subtree root.
            self.val, pivot.val = pivot.val, self.val
            self.left = pivot.left      # pivot's left subtree moves up
            pivot.left = pivot.right    # pivot keeps its right subtree, now on the left
            pivot.right = displaced     # old right subtree goes under the pivot
            self.right = pivot

    def add(self, val):
        """Insert ``val`` as a new leaf; asserts that it is not stored yet."""
        assert val not in self

        def insert(subtree):
            if not subtree:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            return subtree

        self.root = insert(self.root)
        self.size += 1
# Demo: rebalance by hand with rotate_right(). The bare lines after each
# pprint() call are the recorded cell outputs.
# Inserting 3, 2, 1 yields a left-leaning chain.
t = AVLTree()
for x in range(3, 0, -1):
t.add(x)
t.pprint()
3 2 - 1 - - -
# Rotating at the root turns the chain into a balanced 3-node tree.
t.root.rotate_right()
t.pprint()
2 1 3
# Two more smaller values extend the left edge under node 1.
t.add(0)
t.pprint()
2 1 3 0 - - -
t.add(-1)
t.pprint()
2 1 3 0 - - - -1 - - - - - - -
# Rotate at the root's left child, where the chain 1 -> 0 -> -1 starts.
t.root.left.rotate_right()
t.pprint()
2 0 3 -1 1 - -
t.add(-4)
t.add(-5)
t.pprint()
2 0 3 -1 1 - - -4 - - - - - - - -5 - - - - - - - - - - - - - - -
# Fix the chain -1 -> -4 -> -5 first, then the remaining imbalance at the root.
t.root.left.left.rotate_right()
t.pprint()
2 0 3 -4 1 - - -5 -1 - - - - - -
t.root.rotate_right()
t.pprint()
0 -4 2 -5 -1 1 3
class AVLTree(BSTree):
    """BSTree variant that rotates right whenever an insert unbalances a node.

    Every imbalance is treated as a left-left case, and subtree heights are
    recomputed from scratch on every check.
    """

    class Node:
        """Tree cell with an in-place right rotation and a height helper."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

        def rotate_right(self):
            """Rotate the subtree rooted here to the right, in place.

            This node object stays the subtree root: it swaps values with its
            left child and re-hangs that child as its own right child.
            """
            pivot = self.left
            displaced = self.right
            self.val, pivot.val = pivot.val, self.val
            self.left = pivot.left
            pivot.left = pivot.right
            pivot.right = displaced
            self.right = pivot

        @staticmethod
        def height(n):
            """Return the height of the subtree rooted at ``n`` (0 for None).

            Visits every node of the subtree, so the cost is proportional to
            the subtree's size rather than to its height.
            """
            if not n:
                return 0
            return 1 + max(AVLTree.Node.height(n.left),
                           AVLTree.Node.height(n.right))

    def add(self, val):
        """Insert ``val``, rotating right at every node left unbalanced.

        Asserts that ``val`` is not stored yet. The imbalance is assumed to
        be left-left; other shapes are not handled here.
        """
        assert val not in self

        def insert(subtree):
            if not subtree:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            # Detect an imbalance on the way back up and fix it.
            skew = (AVLTree.Node.height(subtree.left)
                    - AVLTree.Node.height(subtree.right))
            if abs(skew) > 1:
                subtree.rotate_right()
            return subtree

        self.root = insert(self.root)
        self.size += 1
# Descending inserts (50, 49, ...): one value is added per evaluation of the
# add/decrement cell. The recorded output below shows the tree once the values
# 50 down to 39 have been added.
val = 50
t = AVLTree()
# (evaluate multiple times with ctrl-enter)
t.add(val)
val -= 1
t.pprint()
43 41 47 40 42 45 49 39 - - - 44 46 48 50
# Ascending inserts (0, 1, ...): the tree becomes right-heavy instead.
# NOTE(review): with an add() that only ever calls rotate_right(), the third
# insert would rotate a node that has no left child and fail; no output was
# recorded for this cell -- confirm that the failure is the point of the demo.
val = 0
t = AVLTree()
# (evaluate multiple times with ctrl-enter)
t.add(val)
val += 1
t.pprint()
# The four cells below build each 3-node imbalance shape on a plain BSTree
# (which never rebalances), one insertion order per shape.
# "left-left" scenario
t = BSTree()
for x in [3, 2, 1]:
t.add(x)
t.pprint()
3 2 - 1 - - -
# "left-right" scenario
t = BSTree()
for x in [3, 1, 2]:
t.add(x)
t.pprint()
3 1 - - 2 - -
# "right-right" scenario
t = BSTree()
for x in [1, 2, 3]:
t.add(x)
t.pprint()
1 - 2 - - - 3
# "right-left" scenario
t = BSTree()
for x in [1, 3, 2]:
t.add(x)
t.pprint()
1 - 3 - - 2 -
class AVLTree:
    """Binary search tree whose nodes cache their subtree height.

    ``add`` rotates right at any node left unbalanced by an insert, treating
    every imbalance as left-left. ``__delitem__`` is a plain binary search
    tree removal: it neither rotates nor refreshes cached heights.
    """

    class Node:
        """Tree cell holding a value, two child links and a cached height."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right
            self.height = 1

        def rotate_right(self):
            """Rotate the subtree rooted here to the right, in place.

            This node object stays the subtree root: it swaps values with its
            left child and re-hangs that child as its own right child. The
            cached heights of both nodes are refreshed, lower node first.
            """
            pivot = self.left
            displaced = self.right
            self.val, pivot.val = pivot.val, self.val
            self.left = pivot.left
            pivot.left = pivot.right
            pivot.right = displaced
            self.right = pivot
            pivot.height = AVLTree.getHeight(pivot)
            self.height = AVLTree.getHeight(self)

    def __init__(self):
        self.root = None
        self.size = 0

    @staticmethod
    def getHeight(n):
        """Return the height of ``n`` from its children's cached heights.

        Constant time: only ``n.left.height`` and ``n.right.height`` are
        read. ``None`` has height 0.
        """
        if not n:
            return 0
        left_h = n.left.height if n.left else 0
        right_h = n.right.height if n.right else 0
        return 1 + max(left_h, right_h)

    def add(self, val):
        """Insert ``val``, refreshing heights and rotating right if needed.

        Asserts that ``val`` is not stored yet.
        """
        assert val not in self

        def insert(subtree):
            if not subtree:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            subtree.height = AVLTree.getHeight(subtree)
            skew = (AVLTree.getHeight(subtree.left)
                    - AVLTree.getHeight(subtree.right))
            if abs(skew) > 1:
                # Assumed to be a left-left imbalance.
                subtree.rotate_right()
            return subtree

        self.root = insert(self.root)
        self.root.height = AVLTree.getHeight(self.root)
        self.size += 1

    def __contains__(self, val):
        """Return True when ``val`` is stored in the tree."""
        def search(subtree):
            if not subtree:
                return False
            if val < subtree.val:
                return search(subtree.left)
            if val > subtree.val:
                return search(subtree.right)
            return True

        return search(self.root)

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    def __delitem__(self, val):
        """Remove ``val`` without rebalancing; asserts that it is stored."""
        assert val in self

        def remove(subtree):
            if val < subtree.val:
                subtree.left = remove(subtree.left)
                return subtree
            if val > subtree.val:
                subtree.right = remove(subtree.right)
                return subtree

            # With at most one child, that child (or None) takes this place.
            if not (subtree.left and subtree.right):
                return subtree.left or subtree.right

            # Two children: overwrite the value with its in-order predecessor
            # (right-most node of the left subtree) and unlink that node.
            parent, pred = subtree, subtree.left
            while pred.right:
                parent, pred = pred, pred.right
            if parent is subtree:
                subtree.left = pred.left
            else:
                parent.right = pred.left
            subtree.val = pred.val
            return subtree

        self.root = remove(self.root)
        self.size -= 1

    def pprint(self, width=64):
        """Print the tree one level per line as ``value,cached_height``.

        Empty slots print as ``-``. Each level splits ``width`` columns
        evenly among its slots.
        """
        depth = self.height()
        pending = [(self.root, 0)]
        pieces = []
        current_level = 0
        while pending:
            cell, level = pending.pop(0)
            if level != current_level:
                current_level = level
                pieces.append('\n')
            slot = width // 2**level
            below_exists = level < depth - 1
            if cell:
                if cell.left or below_exists:
                    pending.append((cell.left, level + 1))
                if cell.right or below_exists:
                    pending.append((cell.right, level + 1))
                label = ",".join((str(cell.val), str(cell.height)))
                pieces.append('{val:^{width}}'.format(val=label, width=slot))
            else:
                if below_exists:
                    pending.extend([(None, level + 1)] * 2)
                pieces.append('{val:^{width}}'.format(val='-', width=slot))
        print(''.join(pieces))

    def height(self):
        """Return the tree height by visiting every node.

        Ignores the cached per-node heights, so the cost is proportional to
        the number of nodes.
        """
        def depth_of(subtree):
            if not subtree:
                return 0
            return 1 + max(depth_of(subtree.left), depth_of(subtree.right))

        return depth_of(self.root)
# Demo of the cached-height AVLTree: pprint() shows each node as
# "value,cached height". The bare lines are the recorded cell outputs.
# Descending inserts trigger a right rotation at the root.
t = AVLTree()
for x in [3, 2, 1]:
t.add(x)
t.pprint()
2,2 1,1 3,1
# Two smaller values create another left-left imbalance, this time below the
# root; the recorded output shows it repaired under the root's left child.
t.add(0.5)
t.add(0.25)
t.pprint()
2,3 0.5,2 3,1 0.25,1 1,1 - -
class AVLTree:
    """Self-balancing (AVL) binary search tree of unique, comparable values.

    Every node caches the height of the subtree it roots. ``add`` and
    ``__delitem__`` refresh those caches and rebalance each node on the path
    from the change back to the root, so sibling subtree heights never differ
    by more than one. Supports ``in``, ``len()`` and ``del tree[val]``.
    """

    class Node:
        """Tree cell holding a value, two child links and a cached height."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
            # A childless node has height 1; if children are supplied up
            # front, derive the height from their cached heights.
            self.height = 1 + max(left.height if left else 0,
                                  right.height if right else 0)

        def rotate_right(self):
            """Rotate the subtree rooted here to the right, in place.

            This node object stays the subtree root (so no parent link needs
            updating): it swaps values with its left child and re-hangs that
            child as its own right child. Requires a left child.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right
            # Refresh the lower node first; this node's height depends on it.
            n.height = AVLTree.getHeight(n)
            self.height = AVLTree.getHeight(self)

        def rotate_left(self):
            """Mirror image of ``rotate_right``. Requires a right child."""
            n = self.right
            self.val, n.val = n.val, self.val
            self.right, n.right, self.left, n.left = n.right, n.left, n, self.left
            n.height = AVLTree.getHeight(n)
            self.height = AVLTree.getHeight(self)

    def __init__(self):
        self.size = 0
        self.root = None

    @staticmethod
    def getHeight(n):
        """Return the height of ``n`` from its children's cached heights.

        Constant time: only ``n.left.height`` and ``n.right.height`` are
        read, so the result is correct whenever the children's caches are.
        ``None`` has height 0.
        """
        if not n:
            return 0
        left_h = n.left.height if n.left else 0
        right_h = n.right.height if n.right else 0
        return max(left_h, right_h) + 1

    @staticmethod
    def _rebalance(node):
        """Refresh ``node.height`` and restore the AVL property at ``node``.

        Must be called bottom-up, i.e. after every descendant of ``node``
        already has a correct cached height. Exactly one of the four
        rotation cases is applied, chosen by which grandchild subtree is
        taller. Returns ``node``, which remains the subtree root because
        rotations work in place.
        """
        get_height = AVLTree.getHeight
        node.height = get_height(node)
        balance = get_height(node.left) - get_height(node.right)
        if balance > 1:
            left = node.left
            # Left-right: straighten the inner grandchild first. On a tie
            # (possible after a deletion) a single rotation is enough.
            if get_height(left.left) < get_height(left.right):
                left.rotate_left()
            node.rotate_right()
        elif balance < -1:
            right = node.right
            # Right-left: mirror image of the case above.
            if get_height(right.right) < get_height(right.left):
                right.rotate_right()
            node.rotate_left()
        return node

    def add(self, val):
        """Insert ``val`` and rebalance every node on the insertion path.

        Raises AssertionError if ``val`` is already stored (the check is
        skipped when Python runs with ``-O``).
        """
        assert val not in self

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            if val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            return AVLTree._rebalance(node)

        self.root = add_rec(self.root)
        self.size += 1

    def __contains__(self, val):
        """Return True when ``val`` is stored in the tree."""
        def contains_rec(node):
            if not node:
                return False
            elif val < node.val:
                return contains_rec(node.left)
            elif val > node.val:
                return contains_rec(node.right)
            else:
                return True

        return contains_rec(self.root)

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    def __delitem__(self, val):
        """Remove ``val`` and rebalance every node on the affected path.

        Raises AssertionError if ``val`` is not stored (the check is skipped
        when Python runs with ``-O``).
        """
        assert val in self

        def pop_max(node):
            """Unlink the largest node below ``node``.

            Returns ``(new_subtree_root, removed_value)``; every node on the
            right spine is rebalanced on the way back up.
            """
            if not node.right:
                return node.left, node.val
            node.right, max_val = pop_max(node.right)
            return AVLTree._rebalance(node), max_val

        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
            elif val > node.val:
                node.right = delitem_rec(node.right)
            elif not node.left:
                # Zero children or a right child only.
                return node.right
            elif not node.right:
                # A left child only.
                return node.left
            else:
                # Two children: replace the value with its in-order
                # predecessor, removed from the left subtree.
                node.left, node.val = pop_max(node.left)
            return AVLTree._rebalance(node)

        self.root = delitem_rec(self.root)
        self.size -= 1

    def pprint(self, width=64):
        """Print the tree one level per line as ``value,cached_height``.

        Empty slots print as ``-``. Each level splits ``width`` columns
        evenly among its slots, so the layout is only readable for shallow
        trees.
        """
        height = self.height()
        nodes = [(self.root, 0)]
        prev_level = 0
        parts = []
        while nodes:
            n, level = nodes.pop(0)
            if prev_level != level:
                prev_level = level
                parts.append('\n')
            slot = width // 2**level
            if not n:
                if level < height - 1:
                    nodes.extend([(None, level + 1), (None, level + 1)])
                parts.append('{val:^{width}}'.format(val='-', width=slot))
            else:
                # Queue a child slot (possibly empty) whenever the level
                # below exists, so every row stays aligned.
                if n.left or level < height - 1:
                    nodes.append((n.left, level + 1))
                if n.right or level < height - 1:
                    nodes.append((n.right, level + 1))
                label = str(n.val) + "," + str(n.height)
                parts.append('{val:^{width}}'.format(val=label, width=slot))
        print(''.join(parts))

    def height(self):
        """Return the tree height by visiting every node.

        Ignores the cached per-node heights, so it is an independent check
        of them; the cost is proportional to the number of nodes.
        """
        def height_rec(t):
            if not t:
                return 0
            return 1 + max(height_rec(t.left), height_rec(t.right))

        return height_rec(self.root)
# Demos for the AVLTree with all four rotation cases. pprint() shows each node
# as "value,cached height"; the bare lines after each pprint() call are the
# cell outputs recorded in the original notebook run.
# Left-left insertion order.
t = AVLTree()
for x in [10, 5, 1]:
t.add(x)
t.pprint()
5,2 1,1 10,1
# Left-right insertion order.
t = AVLTree()
for x in [10, 1, 5]:
t.add(x)
t.pprint()
5,2 1,1 10,1
# Right-right insertion order.
t = AVLTree()
for x in [1, 5, 10]:
t.add(x)
t.pprint()
5,2 1,1 10,1
# Right-left insertion order.
t = AVLTree()
for x in [1, 10, 5]:
t.add(x)
t.pprint()
5,2 1,1 10,1
# Imbalances that arise below and at the root over a longer insert sequence.
t = AVLTree()
for x in [10, 5, 1, 2, 3]:
t.add(x)
t.pprint()
5,3 2,2 10,1 1,1 3,1 - -
# Deletion that unbalances the root: removing the only right child leaves the
# root left-heavy.
t = AVLTree()
for x in [10, 5, 15, 2]:
t.add(x)
t.pprint()
10,3 5,2 15,1 2,1 - - -
del t[15]
t.pprint()
5,2 2,1 10,1
# Larger case: 31 descending inserts give a full 5-level tree in the recorded
# output, followed by a series of deletions.
t = AVLTree()
for x in range(31, 0, -1):
t.add(x)
t.pprint()
16,5 8,4 24,4 4,3 12,3 20,3 28,3 2,2 6,2 10,2 14,2 18,2 22,2 26,2 30,2 1,1 3,1 5,1 7,1 9,1 11,113,115,117,119,121,123,125,127,129,131,1
# Delete a leaf.
del t[15]
t.pprint()
16,5 8,4 24,4 4,3 12,3 20,3 28,3 2,2 6,2 10,2 14,2 18,2 22,2 26,2 30,2 1,1 3,1 5,1 7,1 9,1 11,113,1 - 17,119,121,123,125,127,129,131,1
# Delete a node that has a left child only.
del t[14]
t.pprint()
16,5 8,4 24,4 4,3 12,3 20,3 28,3 2,2 6,2 10,2 13,1 18,2 22,2 26,2 30,2 1,1 3,1 5,1 7,1 9,1 11,1 - - 17,119,121,123,125,127,129,131,1
# Delete a leaf whose removal leaves its parent left-heavy.
del t[13]
t.pprint()
16,5 8,4 24,4 4,3 11,3 20,3 28,3 2,2 6,2 10,2 12,1 18,2 22,2 26,2 30,2 1,1 3,1 5,1 7,1 9,1 - - - 17,119,121,123,125,127,129,131,1
# Delete the root, which has two children and takes its predecessor's value.
del t[16]
t.pprint()
12,5 8,4 24,4 4,3 10,2 20,3 28,3 2,2 6,2 9,1 11,1 18,2 22,2 26,2 30,2 1,1 3,1 5,1 7,1 - - - - 17,119,121,123,125,127,129,131,1