作者 yangfu

feat: utils add tree

  1 +package utils
  2 +
  3 +type TreeNode interface {
  4 + PID() string
  5 + ID() string
  6 +}
  7 +
  8 +type Tree struct {
  9 + Node TreeNode `json:"node"`
  10 + Nodes []*Tree `json:"nodes"`
  11 +}
  12 +
  13 +func NewTree(nodes []TreeNode) *Tree {
  14 + var tree = &Tree{
  15 + Node: nil,
  16 + Nodes: make([]*Tree, 0),
  17 + }
  18 + for i := range nodes {
  19 + match := traverseAdd(tree, nodes[i])
  20 + if !match {
  21 + tree.Nodes = append(tree.Nodes, newTree(nodes[i]))
  22 + }
  23 + }
  24 + return tree
  25 +}
  26 +
  27 +func newTree(node TreeNode) *Tree {
  28 + return &Tree{
  29 + Node: node,
  30 + Nodes: make([]*Tree, 0),
  31 + }
  32 +}
  33 +
  34 +func (tree *Tree) Root() TreeNode {
  35 + if tree.Node != nil {
  36 + return tree.Node
  37 + }
  38 + if len(tree.Nodes) > 0 {
  39 + return tree.Nodes[0].Node
  40 + }
  41 + return nil
  42 +}
  43 +
  44 +// TreeNodePaths returns all the parents of the current node 1->5->7 , use time n*O(n)(need performance optimization)
  45 +func (tree *Tree) TreeNodePaths(node TreeNode) []TreeNode {
  46 + treeNode := node
  47 + result := make([]TreeNode, 0)
  48 + for {
  49 + if treeNode == nil {
  50 + break
  51 + }
  52 + tmp := tree.find(treeNode, func(a, b TreeNode) bool {
  53 + if a.ID() == b.PID() {
  54 + return true
  55 + }
  56 + return false
  57 + })
  58 + result = append(result, treeNode)
  59 + if tmp == nil {
  60 + break
  61 + }
  62 + treeNode = tmp.Node
  63 + }
  64 + reserveResult := make([]TreeNode, 0)
  65 + for i := len(result) - 1; i >= 0; i-- {
  66 + reserveResult = append(reserveResult, result[i])
  67 + }
  68 + return reserveResult
  69 +}
  70 +
  71 +// Add adds a node to the first matching parent tree if add success it return true
  72 +func (tree *Tree) Add(node TreeNode) bool {
  73 + return traverseAdd(tree, node)
  74 +}
  75 +
  76 +// AllChildNode returns all child nodes under Node, including itself
  77 +func (tree *Tree) AllChildNode(node TreeNode) []TreeNode {
  78 + treeNode := tree.find(node, nil)
  79 + if treeNode == nil {
  80 + return []TreeNode{}
  81 + }
  82 + return tree.allChildNode(treeNode, nil)
  83 +}
  84 +
  85 +//AllLeafNode returns all leaf node under Node ,if node is nil returns all leaf node under tree
  86 +func (tree *Tree) AllLeafNode(node TreeNode) []TreeNode {
  87 + treeNode := tree
  88 + if node != nil {
  89 + treeNode = tree.find(node, nil)
  90 + }
  91 + if treeNode == nil {
  92 + return []TreeNode{}
  93 + }
  94 + return tree.allChildNode(treeNode, func(node *Tree) bool {
  95 + if len(node.Nodes) == 0 {
  96 + return true
  97 + }
  98 + return false
  99 + })
  100 +}
  101 +
  102 +// find query the node in this tree
  103 +func (tree *Tree) find(node TreeNode, compared func(a, b TreeNode) bool) *Tree {
  104 + var stack []*Tree
  105 + stack = append(stack, tree)
  106 + var find *Tree
  107 + for {
  108 + if len(stack) == 0 {
  109 + break
  110 + }
  111 + pop := stack[0]
  112 + stack = stack[1:]
  113 + stack = append(stack, pop.Nodes...)
  114 + if pop == nil || pop.Node == nil {
  115 + continue
  116 + }
  117 + if compared != nil {
  118 + if compared(pop.Node, node) {
  119 + find = pop
  120 + break
  121 + }
  122 + continue
  123 + }
  124 + if pop.Node.ID() == node.ID() {
  125 + find = pop
  126 + break
  127 + }
  128 + }
  129 + return find
  130 +}
  131 +
  132 +// allChildNode 返回treeNode下所有子节点
  133 +func (tree *Tree) allChildNode(treeNode *Tree, filter func(node *Tree) bool) []TreeNode {
  134 + var stack []*Tree
  135 + stack = append(stack, treeNode)
  136 + var res []TreeNode
  137 + for {
  138 + if len(stack) == 0 {
  139 + break
  140 + }
  141 + pop := stack[0]
  142 + stack = stack[1:]
  143 + stack = append(stack, pop.Nodes...)
  144 + if filter != nil && !filter(pop) {
  145 + continue
  146 + }
  147 + res = append(res, pop.Node)
  148 + }
  149 + return res
  150 +}
  151 +
  152 +// traverseAdd 递归添加
  153 +//
  154 +// tree 当前树
  155 +// node 判断的节点
  156 +func traverseAdd(tree *Tree, node TreeNode) bool {
  157 + list := tree.Nodes
  158 + var match bool = false
  159 + for i := range list {
  160 + id, pid := list[i].Node.ID(), node.PID()
  161 + if pid == id {
  162 + list[i].Nodes = append(list[i].Nodes, newTree(node))
  163 + return true
  164 + }
  165 + if match || traverseAdd(list[i], node) {
  166 + match = true
  167 + break
  168 + }
  169 + }
  170 + return match
  171 +}
  1 +package utils
  2 +
  3 +import (
  4 + "github.com/stretchr/testify/assert"
  5 + "strconv"
  6 + "testing"
  7 +)
  8 +
  9 +func Test_Tree(t *testing.T) {
  10 + table := []struct {
  11 + Input []TreeNode
  12 + Text string
  13 + Except []string
  14 + Except2 []string
  15 + }{
  16 + {
  17 + Input: []TreeNode{
  18 + &st{Id: 1, Pid: 0},
  19 + &st{Id: 2, Pid: 1}, &st{Id: 3, Pid: 1}, &st{Id: 4, Pid: 1},
  20 + &st{Id: 5, Pid: 3},
  21 + &st{Id: 6, Pid: 5}, &st{Id: 7, Pid: 5}},
  22 + Text: `
  23 +树形结构:
  24 + 1
  25 +2 3 4
  26 + 5
  27 + 6 7
  28 +`,
  29 + Except: []string{"5", "6", "7"},
  30 + Except2: []string{"2", "4", "6", "7"},
  31 + },
  32 + }
  33 +
  34 + for i := range table {
  35 + tree := NewTree(table[i].Input)
  36 + out := tree.AllChildNode(&st{Id: 5, Pid: 3})
  37 + var res []string = treeNodeResults(out)
  38 + assert.Equal(t, res, table[i].Except)
  39 +
  40 + out = tree.AllLeafNode(nil) //tree.Root()
  41 + res = treeNodeResults(out)
  42 + assert.Equal(t, res, table[i].Except2)
  43 +
  44 + root := tree.Root()
  45 + assert.Equal(t, root.ID(), "1")
  46 +
  47 + //tree.Add(&st{Id:10,Pid: 7})
  48 + //
  49 + //out = tree.AllLeafNode(tree.Root())
  50 + //res = treeNodeResults(out)
  51 + //assert.Equal(t, res, []string{"2", "4", "6", "10"})
  52 +
  53 + out = tree.TreeNodePaths(&st{Id: 7, Pid: 5})
  54 + res = treeNodeResults(out)
  55 + assert.Equal(t, res, []string{"1", "3", "5", "7"})
  56 + }
  57 +}
  58 +
  59 +type st struct {
  60 + Id int
  61 + Pid int
  62 +}
  63 +
  64 +func (t *st) PID() string {
  65 + return strconv.Itoa(t.Pid)
  66 +}
  67 +func (t *st) ID() string {
  68 + return strconv.Itoa(t.Id)
  69 +}
  70 +
  71 +func treeNodeResults(nodes []TreeNode) []string {
  72 + var res []string
  73 + for i := range nodes {
  74 + res = append(res, nodes[i].ID())
  75 + }
  76 + return res
  77 +}