tree_test.go 2.8 KB
package xcollection

import (
	"fmt"
	"github.com/stretchr/testify/assert"
	"testing"
)

// prepare returns the table of cases driven by Test_Tree.
//
// Each case carries the flat node list the tree is built from (Input), a
// hand-drawn picture of that tree (Text, which Test_Tree does not read), the
// expected IDs for node 5 together with its descendants (Except), and the
// expected IDs of the leaf nodes (Except2).
//
// NOTE(review): "Except" is presumably a misspelling of "Expect"; the field
// names are kept unchanged because Test_Tree refers to them.
func prepare() []struct {
	Input   []TreeNode
	Text    string
	Except  []string
	Except2 []string
} {
	// Node 1 is the root (Pid 0); 2, 3 and 4 hang off 1; 5 hangs off 3;
	// 6 and 7 hang off 5.
	nodes := []TreeNode{
		&st{Id: 1, Pid: 0},
		&st{Id: 2, Pid: 1},
		&st{Id: 3, Pid: 1},
		&st{Id: 4, Pid: 1},
		&st{Id: 5, Pid: 3},
		&st{Id: 6, Pid: 5},
		&st{Id: 7, Pid: 5},
	}

	diagram := `
树形结构:	 
	 1
2    3    4
     5
    6 7 
`

	return []struct {
		Input   []TreeNode
		Text    string
		Except  []string
		Except2 []string
	}{{
		Input:   nodes,
		Text:    diagram,
		Except:  []string{"5", "6", "7"},
		Except2: []string{"2", "4", "6", "7"},
	}}
}

// Test_Tree runs the basic Tree queries against every case returned by
// prepare. It checks node 5 with its descendants, the leaf nodes, the root,
// and the path from the root down to node 7.
//
// testify's assert.Equal takes (expected, actual). Keeping that order makes
// failure output label the two sides correctly.
func Test_Tree(t *testing.T) {
	for _, tc := range prepare() {
		tree := NewTree(tc.Input)

		// Node 5 and everything below it.
		res := treeNodeResults(tree.AllChildNode(&st{Id: 5, Pid: 3}))
		assert.Equal(t, tc.Except, res)

		// NOTE(review): nil appears to mean "start from the root" (the original
		// comment pointed at tree.Root()); confirm against AllLeafNode.
		res = treeNodeResults(tree.AllLeafNode(nil))
		assert.Equal(t, tc.Except2, res)

		// ID() is an integer. Compare its decimal form with "1" rather than
		// comparing an int64 to a string, which can never be equal.
		assert.Equal(t, "1", fmt.Sprintf("%d", tree.Root().ID()))

		// Disabled in the original; re-enable once tree.Add is under test:
		//   tree.Add(&st{Id: 10, Pid: 7})
		//   leaves from the root should then be {"2", "4", "6", "10"}

		// Path from the root down to node 7.
		res = treeNodeResults(tree.TreeNodePaths(&st{Id: 7, Pid: 5}))
		assert.Equal(t, []string{"1", "3", "5", "7"}, res)
	}
}

func Test_TreeNodeByDepth(t *testing.T) {
	input := []TreeNode{
		&st{Id: 1, Pid: 0},
		&st{Id: 2, Pid: 1}, &st{Id: 3, Pid: 1}, &st{Id: 4, Pid: 1},
		&st{Id: 5, Pid: 3},
		&st{Id: 6, Pid: 5}, &st{Id: 7, Pid: 5},
		&st{Id: 8, Pid: 6}, &st{Id: 9, Pid: 6}, &st{Id: 10, Pid: 6}, &st{Id: 11, Pid: 7}, &st{Id: 12, Pid: 7},
	}

	tree := NewTree(input)
	/*
			树形结构:
				 1
			2    3    4
			     5
			  6          7
		   8  9  10   11   12
	*/
	var out []TreeNode
	var res []string
	out = tree.AllChildNodeByDepth(&st{Id: 5, Pid: 3}, 2)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"6", "7"}, res)
	out = tree.AllChildNodeByDepth(tree.Root(), 1)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"1"}, res)
	out = tree.AllChildNodeByDepth(tree.Root(), 2)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"2", "3", "4"}, res)
	out = tree.AllChildNodeByDepth(tree.Root(), 3)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"5"}, res)
	out = tree.AllChildNodeByDepth(tree.Root(), 4)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"6", "7"}, res)
	out = tree.AllChildNodeByDepth(tree.Root(), 5)
	res = treeNodeResults(out)
	assert.Equal(t, []string{"8", "9", "10", "11", "12"}, res)
}

// st is a minimal TreeNode used by the tests. Id is the node's own
// identifier and Pid is the identifier of its parent; the fixtures in this
// file give the root Pid 0.
type st struct {
	Id  int64
	Pid int64
}

// PID returns the identifier of the node's parent. The original returned Id
// here, swapped with ID.
func (s *st) PID() int64 {
	return s.Pid
}

// ID returns the node's own identifier. The original returned Pid here,
// swapped with PID.
func (s *st) ID() int64 {
	return s.Id
}

// treeNodeResults formats the ID of each node, preserving order, as a
// decimal string so query results can be compared against []string
// fixtures. A nil or empty input yields a nil slice.
func treeNodeResults(nodes []TreeNode) []string {
	var ids []string
	for _, node := range nodes {
		ids = append(ids, fmt.Sprintf("%d", node.ID()))
	}
	return ids
}