给定一个二叉搜索树(BST),找到树中第 K 小的节点

摘要: 给定一个二叉搜索树(BST),找到树中第 K 小的节点

题目:给定一个二叉搜索树(BST),找到树中第 K 小的节点。

出题人:阿里巴巴出题专家:文景/阿里云 CDN 资深技术专家

参考答案:


考察点

1. 基础数据结构的理解和编码能力

2. 递归使用


 示例

       5

      / \

     3   6

    / \

   2   4

  /

 1

说明:保证输入的 K 满足 1<=K<=(节点数目)

解法1:树相关的题目,第一眼就想到递归求解,左右子树分别遍历。联想到二叉搜索树的性质,root 大于左子树,小于右子树,如果左子树的节点数目等于 K-1,那么 root 就是结果,否则如果左子树节点数目小于 K-1,那么结果必然在右子树,否则就在左子树。因此在搜索的时候同时返回节点数目,跟 K 做对比,就能得出结果了。


/**

 * Definition for a binary tree node.

 **/


public class TreeNode {

    int val;

    TreeNode left;

    TreeNode right;

    TreeNode(int x) { val = x; }

}


class Solution {

    private class ResultType {

    

        boolean found;  // 是否找到

        

        int val;  // 节点数目

        ResultType(boolean found, int val) {

            this.found = found;

            this.val = val;

        }

    }


    public int kthSmallest(TreeNode root, int k) {

        return kthSmallestHelper(root, k).val;

    }


    private ResultType kthSmallestHelper(TreeNode root, int k) {

        if (root == null) {

            return new ResultType(false, 0);

        }


        ResultType left = kthSmallestHelper(root.left, k);


        // 左子树找到,直接返回

        if (left.found) {

            return new ResultType(true, left.val);

        }


        // 左子树的节点数目 = K-1,结果为 root 的值

        if (k - left.val == 1) {

            return new ResultType(true, root.val);

        }


        // 右子树寻找

        ResultType right = kthSmallestHelper(root.right, k - left.val - 1);

        if (right.found) {

            return new ResultType(true, right.val);

        }


        // 没找到,返回节点总数

        return new ResultType(false, left.val + 1 + right.val);

    }

}


解法2:基于二叉搜索树的特性,在中序遍历的结果中,第k个元素就是本题的解。

最差的情况是k节点是bst的最右叶子节点,不过`每个节点的遍历次数最多是1次`。

遍历并不是需要全部做完,使用计数的方式,找到第k个元素就可以退出。

下面是go的一个简单实现。


// BST is binary search tree

type BST struct {

key, value  int

left, right *BST

}


func (bst *BST) setLeft(b *BST) {

bst.left = b

}


func (bst *BST) setRight(b *BST) {

bst.right = b

}


// count 查找bst第k个节点的值,未找到就返回0

func count(bst *BST, k int) int {

if k < 1 {

return 0

}


c := 0

ok, value := countRecursive(bst, &c, k)


if ok {

return value

}


return 0

}


// countRecurisive 对bst使用中序遍历

// 用计数方式控制退出遍历,参数c就是已遍历节点数

func countRecursive(bst *BST, c *int, k int) (bool, int) {

if bst.left != nil {

ok, value := countRecursive(bst.left, c, k)

if ok {

return ok, value

}

}


if *c == k-1 {

return true, bst.value

}


*c++


if bst.right != nil {

ok, value := countRecursive(bst.right, c, k)

if ok {

return ok, value

}

}


return false, 0

}


// 下面是测试代码,覆盖了退化的情况和普通bst


func createBST1() *BST {

b1 := &BST{key: 1, value: 10}

b2 := &BST{key: 2, value: 20}

b3 := &BST{key: 3, value: 30}

b4 := &BST{key: 4, value: 40}

b5 := &BST{key: 5, value: 50}

b6 := &BST{key: 6, value: 60}

b7 := &BST{key: 7, value: 70}

b8 := &BST{key: 8, value: 80}

b9 := &BST{key: 9, value: 90}


b9.setLeft(b8)

b8.setLeft(b7)

b7.setLeft(b6)

b6.setLeft(b5)

b5.setLeft(b4)

b4.setLeft(b3)

b3.setLeft(b2)

b2.setLeft(b1)


return b9

}


func createBST2() *BST {

b1 := &BST{key: 1, value: 10}

b2 := &BST{key: 2, value: 20}

b3 := &BST{key: 3, value: 30}

b4 := &BST{key: 4, value: 40}

b5 := &BST{key: 5, value: 50}

b6 := &BST{key: 6, value: 60}

b7 := &BST{key: 7, value: 70}

b8 := &BST{key: 8, value: 80}

b9 := &BST{key: 9, value: 90}


b1.setRight(b2)

b2.setRight(b3)

b3.setRight(b4)

b4.setRight(b5)

b5.setRight(b6)

b6.setRight(b7)

b7.setRight(b8)

b8.setRight(b9)


return b1

}


func createBST3() *BST {

b1 := &BST{key: 1, value: 10}

b2 := &BST{key: 2, value: 20}

b3 := &BST{key: 3, value: 30}

b4 := &BST{key: 4, value: 40}

b5 := &BST{key: 5, value: 50}

b6 := &BST{key: 6, value: 60}

b7 := &BST{key: 7, value: 70}

b8 := &BST{key: 8, value: 80}

b9 := &BST{key: 9, value: 90}


b5.setLeft(b3)

b5.setRight(b7)

b3.setLeft(b2)

b3.setRight(b4)

b2.setLeft(b1)

b7.setLeft(b6)

b7.setRight(b8)

b8.setRight(b9)


return b5

}


func createBST4() *BST {

b := &BST{key: 1, value: 10}

last := b


for i := 2; i < 100000; i++ {

n := &BST{key: i, value: i * 10}

last.setRight(n)


last = n

}


return b

}


func createBST5() *BST {

b := &BST{key: 99999, value: 999990}

last := b


for i := 99998; i > 0; i-- {

n := &BST{key: i, value: i * 10}

last.setLeft(n)


last = n

}


return b

}


func createBST6() *BST {

b := &BST{key: 50000, value: 500000}

last := b


for i := 49999; i > 0; i-- {

n := &BST{key: i, value: i * 10}

last.setLeft(n)


last = n

}


last = b


for i := 50001; i < 100000; i++ {

n := &BST{key: i, value: i * 10}

last.setRight(n)


last = n

}


return b

}


func TestK(t *testing.T) {

bst1 := createBST1()

bst2 := createBST2()

bst3 := createBST3()

bst4 := createBST4()


check(t, bst1, 1, 10)

check(t, bst1, 2, 20)

check(t, bst1, 3, 30)

check(t, bst1, 4, 40)

check(t, bst1, 5, 50)

check(t, bst1, 6, 60)

check(t, bst1, 7, 70)

check(t, bst1, 8, 80)

check(t, bst1, 9, 90)


check(t, bst2, 1, 10)

check(t, bst2, 2, 20)

check(t, bst2, 3, 30)

check(t, bst2, 4, 40)

check(t, bst2, 5, 50)

check(t, bst2, 6, 60)

check(t, bst2, 7, 70)

check(t, bst2, 8, 80)

check(t, bst2, 9, 90)


check(t, bst3, 1, 10)

check(t, bst3, 2, 20)

check(t, bst3, 3, 30)

check(t, bst3, 4, 40)

check(t, bst3, 5, 50)

check(t, bst3, 6, 60)

check(t, bst3, 7, 70)

check(t, bst3, 8, 80)

check(t, bst3, 9, 90)


check(t, bst4, 1, 10)

check(t, bst4, 2, 20)

check(t, bst4, 3, 30)

check(t, bst4, 4, 40)

check(t, bst4, 5, 50)

check(t, bst4, 6, 60)

check(t, bst4, 7, 70)

check(t, bst4, 8, 80)

check(t, bst4, 9, 90)


check(t, bst4, 99991, 999910)

check(t, bst4, 99992, 999920)

check(t, bst4, 99993, 999930)

check(t, bst4, 99994, 999940)

check(t, bst4, 99995, 999950)

check(t, bst4, 99996, 999960)

check(t, bst4, 99997, 999970)

check(t, bst4, 99998, 999980)

check(t, bst4, 99999, 999990)

}


func check(t *testing.T, b *BST, k, value int) {

t.Helper()


checkCall(t, b, k, value, count)

// 此处可添加其他解法的实现

}


func checkCall(t *testing.T, b *BST, k, value int, find func(bst *BST, kth int) int) {

t.Helper()


got := find(b, k)

if got != value {

t.Fatalf("want:%d, got:%d", value, got)

}

}


网友留言评论

0条评论