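Diameter of a Binary Tree

Here the diameter is the number of nodes on the longest path between any two nodes of the tree, so the five-node example below has diameter 4 (e.g. the path 4-2-1-3).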


In [9]:
class Node:
    """A binary search tree node: a key plus child and parent links."""

    def __init__(self, value):
        self.value = value
        # Child links; None marks an empty subtree.
        self.left = self.right = None
        # Parent link, assigned by insert(); stays None for the root.
        self.p = None

    def __str__(self):
        # Rendered as "(value)-", e.g. "(6)-".
        return f"({self.value!s})-"

In [3]:
def height(node):
    """Return the number of levels (nodes on the longest root-to-leaf path); 0 for None."""
    levels = 0
    frontier = [] if node is None else [node]
    # Walk the tree one level at a time; each non-empty level adds one to the height.
    while frontier:
        levels += 1
        frontier = [
            child
            for parent in frontier
            for child in (parent.left, parent.right)
            if child is not None
        ]
    return levels
    
def diameter(node):
    """Return the diameter of the tree rooted at ``node``.

    The diameter is counted in nodes: the number of nodes on the longest path
    between any two nodes. An empty tree (None) has diameter 0 and a single
    node has diameter 1.

    Heights and diameters are computed together in one post-order pass, so the
    whole call is O(n). Recomputing height() at every node would be O(n^2) on
    skewed trees.
    """
    def _walk(n):
        # Return (height, diameter) of the subtree rooted at n.
        if n is None:
            return 0, 0
        lh, ld = _walk(n.left)
        rh, rd = _walk(n.right)
        # Longest path through n joins the deepest left and right branches plus n.
        return 1 + max(lh, rh), max(lh + rh + 1, ld, rd)

    return _walk(node)[1]
    
# Sample tree:
#         1
#        / \
#       2   3
#      / \
#     4   5
root = Node(1)
root.left, root.right = Node(2), Node(3)
root.left.left, root.left.right = Node(4), Node(5)

print(height(root))
print(diameter(root))


3
4

Basic Binary Search Tree Operations


In [30]:
def inorder(node):
    """Print the subtree in-order on one line, wrapping each subtree in [ ]."""
    if node is None:
        return
    print("[", end='')
    inorder(node.left)
    print("(" + str(node.value) + ")", end='')
    inorder(node.right)
    print("]", end='')
        
def insert(root, node):
    """Insert ``node`` into the binary search tree rooted at ``root``.

    Keys smaller than a node's value go left; equal or larger keys go right.
    The new node's parent link ``p`` is set. Debug lines naming the parent and
    the side chosen are printed.

    Returns the root of the tree: ``root`` itself, or ``node`` when ``root`` is
    None (the node becomes the root of a new tree).
    """
    if root is None:
        # Empty tree: without this guard the parent lookup below ends with
        # parent None and crashes on parent.value.
        node.p = None
        return node

    parent = None
    cur = root
    while cur is not None:
        parent = cur
        cur = cur.left if node.value < cur.value else cur.right

    print("node(" + str(parent))
    if node.value < parent.value:
        parent.left = node
        print("insert at left:" + str(node))
    else:
        parent.right = node
        print("insert at right:" + str(node))
    node.p = parent
    return root
        
root = Node(6)

# Insert in this exact order; the duplicate 5 ends up in the right subtree of the first 5.
for key in (5, 7, 8, 2, 5):
    insert(root, Node(key))

inorder(root)


node((6)-
insert at left:(5)-
node((6)-
insert at right:(7)-
node((7)-
insert at right:(8)-
node((5)-
insert at left:(2)-
node((5)-
insert at right:(5)-
[[[(2)](5)[(5)]](6)[(7)[(8)]]]

In [31]:
def minimum(node):
    """Return the leftmost (smallest-key) node of the non-empty subtree at ``node``."""
    leftmost = node
    while leftmost.left is not None:
        leftmost = leftmost.left
    return leftmost

def search(node, value):
    """Return the node holding ``value`` in the BST rooted at ``node``, or None if absent."""
    cur = node
    # Descend until we hit the key or fall off the tree.
    while cur is not None and value != cur.value:
        cur = cur.left if value < cur.value else cur.right
    return cur


def successor(node):
    """Return the in-order successor of ``node``, or None if it holds the largest key.

    Relies on parent links (``.p``) such as those set by insert().
    """
    if node.right is not None:
        # The successor is the smallest node of the RIGHT subtree. Calling
        # minimum(node) would descend into the left subtree and return a
        # smaller key instead.
        return minimum(node.right)

    # No right subtree: climb while we are a right child. The first ancestor
    # reached from its left side is the successor.
    cur = node
    parent = node.p
    while parent is not None and cur is parent.right:
        cur = parent
        parent = parent.p
    return parent

root = Node(15)
for key in (6, 18, 17, 20, 3, 2, 4, 7, 13, 9):
    insert(root, Node(key))

inorder(root)

# 13 has no right child, so its successor is found by climbing parent links.
target = search(root, 13)
print()
print(target)
print(successor(target))


node((15)-
insert at left:(6)-
node((15)-
insert at right:(18)-
node((18)-
insert at left:(17)-
node((18)-
insert at right:(20)-
node((6)-
insert at left:(3)-
node((3)-
insert at left:(2)-
node((3)-
insert at right:(4)-
node((6)-
insert at right:(7)-
node((7)-
insert at right:(13)-
node((13)-
insert at left:(9)-
[[[[(2)](3)[(4)]](6)[(7)[[(9)](13)]]](15)[[(17)](18)[(20)]]]
(13)-
(15)-

In [36]:
def kth_small(node, k, i=None):
    """Return the k-th smallest node (1-based) of the BST rooted at ``node``, or None.

    Does an in-order traversal and stops as soon as the k-th node is reached.
    Prints "<count> <value>" for each visited node and "!" on the hit.

    ``i`` is a one-element list used as a visit counter shared across the
    recursion. Callers may omit it (a fresh counter is created) or pass
    ``[0]`` as before.
    """
    if i is None:
        i = [0]
    if node is None:
        return None

    found = kth_small(node.left, k, i)
    if found is not None:
        return found

    i[0] += 1
    print(i[0], node.value)
    if i[0] == k:
        print("!")
        return node

    # Whatever the right subtree yields (a node or None) is the answer.
    return kth_small(node.right, k, i)

root = Node(15)
for key in (6, 18, 17, 20, 3, 2, 4, 7, 13, 9):
    insert(root, Node(key))

# The in-order sequence is 2 3 4 6 7 9 13 15 17 18 20, so k=4 yields 6.
print(kth_small(root, 4, [0]))


node((15)-
insert at left:(6)-
node((15)-
insert at right:(18)-
node((18)-
insert at left:(17)-
node((18)-
insert at right:(20)-
node((6)-
insert at left:(3)-
node((3)-
insert at left:(2)-
node((3)-
insert at right:(4)-
node((6)-
insert at right:(7)-
node((7)-
insert at right:(13)-
node((13)-
insert at left:(9)-
1 2
2 3
3 4
4 6
!
(6)-

In [ ]:
# Traversal state for kth_small2: the node found so far and how many nodes have been visited.
ret = None
count = 0


def _kth_small2_walk(node, k):
    """In-order walk that counts visits in global ``count`` and stores the k-th node in ``ret``."""
    global count, ret
    if node is None:
        return
    _kth_small2_walk(node.left, k)
    if ret is not None:
        # Already found in the left subtree; stop visiting.
        return
    count += 1
    print(count, node.value)
    if count == k:
        print("!")
        ret = node
        return
    _kth_small2_walk(node.right, k)


def kth_small2(node, k):
    """Return the k-th smallest node (1-based) of the BST at ``node``, or None.

    Global-state variant of kth_small: the result is also left in the module
    global ``ret``. The state is reset on every call. Without the reset, a
    second call would continue counting from the previous one and return the
    wrong node.
    """
    global count, ret
    count = 0
    ret = None
    _kth_small2_walk(node, k)
    return ret
    
# Arguments are evaluated left to right, so `ret` is read after kth_small2 has updated it.
print(kth_small2(root, 4), ret, sep="\n")

In [ ]: