class BSTree:
    """Unbalanced binary search tree of distinct, mutually comparable values.

    Smaller values live in a node's left subtree, larger ones in its right
    subtree. Every operation is recursive, so its cost and recursion depth
    grow with the tree's height, which equals the number of stored values
    when values are inserted in sorted order.
    """

    class Node:
        """A tree node: a value plus optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        self.size = 0     # number of stored values, reported by len()
        self.root = None  # root Node, or None for an empty tree

    def add(self, val):
        """Insert ``val`` as a new leaf; ``val`` must not already be present.

        Nodes are created through ``self.Node`` so that a subclass overriding
        the nested ``Node`` class gets its own node type without having to
        copy this method.
        """
        assert(val not in self)
        node_cls = self.Node

        def add_rec(node):
            if not node:
                return node_cls(val)
            elif val < node.val:
                node.left = add_rec(node.left)
                return node
            else:
                node.right = add_rec(node.right)
                return node

        self.root = add_rec(self.root)
        self.size += 1

    def __contains__(self, val):
        """Return True if ``val`` is stored in the tree (supports ``val in tree``)."""
        def contains_rec(node):
            if not node:
                return False
            elif val < node.val:
                return contains_rec(node.left)
            elif val > node.val:
                return contains_rec(node.right)
            else:
                return True
        return contains_rec(self.root)

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    def __iter__(self):
        """Yield the stored values in ascending (in-order) order."""
        def iter_rec(node):
            if node:
                yield from iter_rec(node.left)
                yield node.val
                yield from iter_rec(node.right)
        return iter_rec(self.root)

    def __delitem__(self, val):
        """Remove ``val`` (``del tree[val]``); ``val`` must be present.

        A node with two children takes over the value of its in-order
        predecessor (the right-most node of its left subtree), and that
        predecessor node is unlinked instead.
        """
        assert(val in self)

        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
                return node
            elif val > node.val:
                node.right = delitem_rec(node.right)
                return node
            else:
                if not node.left and not node.right:
                    return None
                elif node.left and not node.right:
                    return node.left
                elif node.right and not node.left:
                    return node.right
                else:
                    to_del = node.left
                    if not to_del.right:
                        node.left = to_del.left
                    else:
                        par = to_del
                        to_del = par.right
                        while to_del.right:
                            par, to_del = to_del, to_del.right
                        # to_del is the right-most node, par is its parent
                        par.right = to_del.left
                    node.val = to_del.val
                    return node

        self.root = delitem_rec(self.root)
        self.size -= 1

    def pprint(self, width=64):
        """Attempt to pretty-print this tree's contents, one level per line.

        Nodes are visited breadth-first. A slot at depth ``level`` is centred
        in a field ``width // 2**level`` characters wide, and missing children
        are drawn as ``-`` down to the tree's last level so that each value
        stays roughly beneath its parent.
        """
        from collections import deque

        height = self.height()
        queue = deque([(self.root, 0)])  # O(1) popleft, unlike list.pop(0)
        pieces = []
        prev_level = 0
        while queue:
            node, level = queue.popleft()
            if level != prev_level:
                prev_level = level
                pieces.append('\n')
            has_deeper_level = level < height - 1
            if not node:
                # keep placeholder slots so the levels below stay aligned
                if has_deeper_level:
                    queue.extend([(None, level + 1), (None, level + 1)])
                label = '-'
            else:
                if node.left or has_deeper_level:
                    queue.append((node.left, level + 1))
                if node.right or has_deeper_level:
                    queue.append((node.right, level + 1))
                label = node.val
            pieces.append('{val:^{width}}'.format(val=label, width=width // 2**level))
        print(''.join(pieces))

    def height(self):
        """Return the number of levels on the longest root-to-leaf path (0 if empty)."""
        def height_rec(t):
            if not t:
                return 0
            return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)
# Insert the already-sorted values 0..5 into a plain BSTree. Each new value is
# larger than everything already stored, so it always becomes a right child
# and the tree degenerates into a chain (as the printed output shows).
t = BSTree()
for x in range(6):
    t.add(x)
t.pprint()
0 - 1 - - - 2 - - - - - - - 3 - - - - - - - - - - - - - - - 4 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 5
# Demonstrate the cost of an unbalanced tree: inserting 100 already-sorted
# values into a plain BSTree builds a chain 100 nodes deep. With a deliberately
# small recursion limit, the recursive add() is expected to raise
# RecursionError. The previous limit is restored afterwards (even when the
# error propagates) so that the lowered limit does not leak into later cells.
import sys
_saved_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(100)
try:
    t = BSTree()
    for x in range(100):
        t.add(x)
finally:
    sys.setrecursionlimit(_saved_recursion_limit)
class AVLTree(BSTree):
    """First AVL step: a BST whose nodes can rotate right in place.

    Insertion is still plain BST insertion; nothing rebalances the tree yet.
    """

    class Node:
        """Tree node: a value and optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

        def rotate_right(self):
            """Rotate the subtree rooted at this node to the right, in place.

            Values are swapped instead of relinking nodes, so this object stays
            the subtree root and any parent's pointer to it remains correct.
            Assumes a left child exists.
            """
            pivot = self.left
            # the left child's value moves up; this node's value moves down-right
            self.val, pivot.val = pivot.val, self.val
            displaced_right = self.right
            self.left = pivot.left
            pivot.left = pivot.right
            pivot.right = displaced_right
            self.right = pivot

    def add(self, val):
        """Insert ``val`` (which must not already be present) as a new leaf."""
        assert val not in self

        def insert(subtree):
            if subtree is None:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            return subtree

        self.root = insert(self.root)
        self.size += 1
# Insert 6, 5, ..., 1 into this first AVLTree version. Its add() does not
# rebalance yet, so every value becomes a left child: a left-leaning chain.
t = AVLTree()
for x in range(6, 0, -1):
    t.add(x)
t.pprint()
6 5 - 4 - - - 3 - - - - - - - 2 - - - - - - - - - - - - - - - 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Rotate right about the root. rotate_right swaps values rather than nodes,
# so t.root is still the same object afterwards.
t.root.rotate_right()
t.pprint()
5 4 6 3 - - - 2 - - - - - - - 1 - - - - - - - - - - - - - - -
# Rotate right about the root again.
t.root.rotate_right()
t.pprint()
4 3 5 2 - - 6 1 - - - - - - -
# Rotate right about the root's left child: only that subtree is rearranged.
t.root.left.rotate_right()
t.pprint()
4 2 5 1 3 - 6
# Rotate right about the root once more.
t.root.rotate_right()
t.pprint()
2 1 4 - - 3 5 - - - - - - - 6
class AVLTree(BSTree):
    """Second AVL step: insertion repairs left-heavy imbalances only.

    While an insertion unwinds, any node whose left subtree is more than one
    level taller than its right subtree receives a single right rotation.
    Nothing checks for right-heavy nodes yet, and the single right rotation
    is applied even when the left child leans to the right.
    """

    class Node:
        """Tree node: a value and optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

        def rotate_right(self):
            """Rotate this subtree to the right in place (assumes a left child).

            Values are swapped instead of relinking nodes, so this object stays
            the subtree root and its parent's link remains correct.
            """
            pivot = self.left
            self.val, pivot.val = pivot.val, self.val
            displaced_right = self.right
            self.left = pivot.left
            pivot.left = pivot.right
            pivot.right = displaced_right
            self.right = pivot

        @staticmethod
        def height(n):
            """Count the levels of the subtree rooted at ``n`` (0 for None)."""
            if n is None:
                return 0
            return 1 + max(AVLTree.Node.height(n.left), AVLTree.Node.height(n.right))

    def add(self, val):
        """Insert ``val`` (must be absent), fixing left-heavy nodes on the way up."""
        assert val not in self
        subtree_height = AVLTree.Node.height

        def insert(subtree):
            if subtree is None:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            # left side more than one level taller than the right: rotate right
            if subtree_height(subtree.left) - subtree_height(subtree.right) > 1:
                subtree.rotate_right()
            return subtree

        self.root = insert(self.root)
        self.size += 1
# Descending inserts into the AVLTree version that fixes left-heavy nodes.
val = 50
t = AVLTree()
# (evaluate multiple times with ctrl-enter)
# NOTE: the recorded output below came from re-running only the next three
# lines many times (inserting 50, 49, 48, ...); re-running the two lines
# above starts over with an empty tree.
t.add(val)
val -= 1
t.pprint()
35 27 43 25 31 39 47 23 26 29 33 37 41 45 49 22 24 - - 28 30 32 34 36 38 40 42 44 46 48 50
# Same experiment with ascending values. This AVLTree version only checks for
# a too-tall LEFT subtree, so increasing values (which always go right) are
# never rebalanced and the tree stays a right-leaning chain.
val = 0
t = AVLTree()
# (evaluate multiple times with ctrl-enter)
# NOTE: re-run only the next three lines to keep inserting 0, 1, 2, ...;
# re-running the two lines above starts over with an empty tree.
t.add(val)
val += 1
t.pprint()
0 - 1 - - - 2 - - - - - - - 3 - - - - - - - - - - - - - - - 4
# "left-left" scenario
# Plain BSTree for comparison: inserting 3, 2, 1 makes each value the left
# child of the previous one.
t = BSTree()
for x in [3, 2, 1]:
    t.add(x)
t.pprint()
3 2 - 1 - - -
# "left-left" scenario
# Same insertion order with AVLTree: the left-heavy root is fixed by a single
# right rotation, giving 2 with children 1 and 3.
t = AVLTree()
for x in [3, 2, 1]:
    t.add(x)
t.pprint()
2 1 3
# "left-right" scenario
# Plain BSTree: 1 is the left child of 3, and 2 is the RIGHT child of 1.
t = BSTree()
for x in [3, 1, 2]:
    t.add(x)
t.pprint()
3 1 - - 2 - -
# "left-right" scenario
# With this AVLTree version, a single right rotation at the root is not enough
# for this shape: the result (below) is still unbalanced, just mirrored.
t = AVLTree()
for x in [3, 1, 2]:
    t.add(x)
t.pprint()
1 - 3 - - 2 -
# "right-right" scenario
# Plain BSTree: inserting 1, 2, 3 builds a right-leaning chain.
t = BSTree()
for x in [1, 2, 3]:
    t.add(x)
t.pprint()
1 - 2 - - - 3
# "right-left" scenario
# Plain BSTree: 3 is the right child of 1, and 2 is the LEFT child of 3.
t = BSTree()
for x in [1, 3, 2]:
    t.add(x)
t.pprint()
1 - 3 - - 2 -
class AVLTree(BSTree):
    """Third AVL step: detects imbalance on either side and tells apart the
    left-left and left-right shapes.

    Still incomplete: ``Node.rotate_left`` is a placeholder that does nothing,
    and right-heavy imbalances are left untouched, so some insertion orders
    still produce unbalanced trees. ``rebalance`` prints which case it found.
    """

    class Node:
        """Tree node: a value and optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

        def rotate_right(self):
            """Rotate this subtree to the right in place (assumes a left child).

            Values are swapped instead of relinking nodes, so this object stays
            the subtree root.
            """
            pivot = self.left
            self.val, pivot.val = pivot.val, self.val
            displaced_right = self.right
            self.left = pivot.left
            pivot.left = pivot.right
            pivot.right = displaced_right
            self.right = pivot

        def rotate_left(self):
            """Placeholder for the mirror of rotate_right; currently does nothing."""

        @staticmethod
        def height(n):
            """Count the levels of the subtree rooted at ``n`` (0 for None)."""
            if n is None:
                return 0
            return 1 + max(AVLTree.Node.height(n.left), AVLTree.Node.height(n.right))

    @staticmethod
    def rebalance(node):
        """Repair an imbalance at ``node``; only the left-heavy cases exist so far."""
        height = AVLTree.Node.height
        if height(node.left) <= height(node.right):
            # right-branch imbalance handling not written yet
            return
        child = node.left
        if height(child.left) < height(child.right):
            # left-right: first straighten the left child (still a no-op here)
            print('left-right imbalance detected')
            child.rotate_left()
        else:
            # left-left
            print('left-left imbalance detected')
        node.rotate_right()

    def add(self, val):
        """Insert ``val`` (must be absent), rebalancing nodes on the way back up."""
        assert val not in self
        height = AVLTree.Node.height

        def insert(subtree):
            if subtree is None:
                return AVLTree.Node(val)
            if val < subtree.val:
                subtree.left = insert(subtree.left)
            else:
                subtree.right = insert(subtree.right)
            # subtree heights differ by two or more: an AVL violation
            if abs(height(subtree.left) - height(subtree.right)) > 1:
                AVLTree.rebalance(subtree)
            return subtree

        self.root = insert(self.root)
        self.size += 1
# Left-left case with the AVLTree version that distinguishes imbalance shapes:
# rebalance() prints which case it detected, then rotates right.
t = AVLTree()
for x in [10, 5, 1]:
    t.add(x)
t.pprint()
left-left imbalance detected 5 1 10
# broken!
# Inserting 2 and then 3 first makes node 1 right-heavy (ignored: right-side
# cases are not implemented) and then node 5 left-right heavy; rotate_left is
# still a no-op there, so the final tree is not balanced.
t = AVLTree()
for x in [10, 5, 1, 2, 3]:
    t.add(x)
t.pprint()
left-left imbalance detected left-right imbalance detected 1 - 5 - - 2 10 - - - - - 3 - -
class AVLTree(BSTree):
    """Self-balancing binary search tree (AVL tree) with insertion and deletion.

    After an insertion or deletion, every node on the affected path whose two
    subtree heights differ by 2 or more is repaired with a single or double
    rotation (left-left, left-right, right-right and right-left cases).

    Rotations swap node *values* rather than relinking nodes, so the object at
    the root of a rotated subtree never changes and its parent's child pointer
    stays valid.

    Subtree heights are recomputed recursively whenever they are needed (they
    are not cached on the nodes). Each height query therefore costs time
    proportional to the subtree's size, so add/delete do O(N) work even though
    the tree's shape stays balanced.
    """

    class Node:
        """AVL tree node: a value plus optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def rotate_right(self):  # O(1)
            """Rotate this subtree to the right in place; requires a left child.

            The left child's value moves up into this node, and this node's old
            value moves down into the reused left-child object, which becomes
            the new right child.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

        def rotate_left(self):  # O(1)
            """Rotate this subtree to the left in place; requires a right child.

            This is the mirror image of ``rotate_right``: the right child's value
            moves up, and the reused right-child object becomes the new left
            child holding this node's old value.
            """
            n = self.right
            self.val, n.val = n.val, self.val
            self.right, n.right, self.left, n.left = n.right, n.left, n, self.left

        @staticmethod
        def height(n):
            """Return the number of levels in the subtree rooted at ``n`` (0 if None)."""
            if not n:
                return 0
            return 1 + max(AVLTree.Node.height(n.left), AVLTree.Node.height(n.right))

    @staticmethod
    def rebalance(node):
        """Rotate at ``node`` so that its taller side is brought back into balance.

        The caller is expected to have detected a height difference of 2 or
        more. When the two sides are equally tall, nothing is done.
        """
        height = AVLTree.Node.height
        left_h, right_h = height(node.left), height(node.right)
        if left_h > right_h:
            if height(node.left.left) >= height(node.left.right):
                # left-left: a single right rotation suffices
                node.rotate_right()
            else:
                # left-right: turn it into left-left first
                node.left.rotate_left()
                node.rotate_right()
        elif right_h > left_h:
            if height(node.right.right) >= height(node.right.left):
                # right-right: a single left rotation suffices
                node.rotate_left()
            else:
                # right-left: turn it into right-right first
                node.right.rotate_right()
                node.rotate_left()

    @staticmethod
    def _rebalance_if_needed(node):
        """Rebalance ``node`` if its subtree heights differ by 2 or more."""
        if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
            AVLTree.rebalance(node)

    def add(self, val):
        """Insert ``val`` (must not already be present) and restore balance.

        Every node on the insertion path is checked on the way back up. Heights
        are recomputed recursively, so this does O(N) work per call.
        """
        assert(val not in self)

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            elif val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            # detect and fix imbalance on the way back up
            AVLTree._rebalance_if_needed(node)
            return node

        self.root = add_rec(self.root)
        self.size += 1

    def __delitem__(self, val):
        """Remove ``val`` (``del tree[val]``; must be present) and restore balance.

        A node with two children takes over the value of its in-order
        predecessor (the right-most node of its left subtree), which is
        unlinked instead. Heights are recomputed recursively, so this does O(N)
        work per call.
        """
        assert(val in self)

        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
            elif val > node.val:
                node.right = delitem_rec(node.right)
            else:
                if not node.left and not node.right:
                    return None
                elif node.left and not node.right:
                    return node.left
                elif node.right and not node.left:
                    return node.right
                else:
                    to_del = node.left
                    if not to_del.right:
                        node.left = to_del.left
                    else:
                        par = to_del
                        to_del = par.right
                        to_fix = [par]
                        while to_del.right:
                            par, to_del = to_del, to_del.right
                            to_fix.append(par)
                        # to_del is the right-most node, par is its parent
                        par.right = to_del.left
                        # nodes on the path to the predecessor lost height on
                        # their right side; check them bottom-up
                        for n in reversed(to_fix):
                            AVLTree._rebalance_if_needed(n)
                    node.val = to_del.val
            # detect and fix imbalance (each recursive level checks its own node)
            AVLTree._rebalance_if_needed(node)
            return node

        self.root = delitem_rec(self.root)
        self.size -= 1

    def __iter__(self):
        """Yield the stored values in ascending (in-order) order."""
        def iter_rec(n):
            if n:
                yield from iter_rec(n.left)
                yield n.val
                yield from iter_rec(n.right)
        return iter_rec(self.root)
# Final AVLTree (with deletion): a small tree whose left side is one level deeper.
t = AVLTree()
for x in [10, 5, 15, 2]:
    t.add(x)
t.pprint()
10 5 15 2 - - -
# Deleting 15 leaves the root's left subtree two levels taller than its right
# one, so the root is rebalanced (output: 5 with children 2 and 10).
del t[15]
t.pprint()
5 2 10
# Insert 31, 30, ..., 1 in descending order; the repeated rebalancing yields a
# perfectly balanced tree with 5 levels (see the output).
t = AVLTree()
for x in range(31, 0, -1):
    t.add(x)
t.pprint()
16 8 24 4 12 20 28 2 6 10 14 18 22 26 30 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31
# Delete 15 (a leaf), then 14 (which now has a single child, 13, that takes
# its place).
del t[15]
del t[14]
t.pprint()
16 8 24 4 12 20 28 2 6 10 13 18 22 26 30 1 3 5 7 9 11 - - 17 19 21 23 25 27 29 31
# Delete the root, which has two children: its value is replaced by its
# in-order predecessor (the largest value in the left subtree, 13 here), and
# the nodes on the path to that predecessor are checked for rebalancing.
del t[16]
t.pprint()
13 8 24 4 10 20 28 2 6 9 12 18 22 26 30 1 3 5 7 - - 11 - 17 19 21 23 25 27 29 31
# Iterating an AVLTree yields its values in ascending (in-order) order.
for x in t:
    print(x)
1 2 3 4 5 6 7 8 9 10 11 12 13 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31