package xcollection

// TreeNode is the payload stored at each position of a Tree.
// When a tree is assembled, a node is attached beneath the node whose ID()
// equals its own PID().
type TreeNode interface {
	// PID returns the identifier of this node's parent.
	PID() string
	// ID returns this node's own identifier.
	ID() string
}

// Tree is one position in a tree of TreeNode values: a payload together with
// its child subtrees. The value returned by NewTree is a holder whose Node is
// nil and whose Nodes are the top-level subtrees.
type Tree struct {
	// Node is the payload at this position; serialized under the JSON key "chart".
	Node  TreeNode `json:"chart"`
	// Nodes holds the direct child subtrees; serialized under the JSON key "charts".
	Nodes []*Tree  `json:"charts"`
}

// NewTree assembles a tree from a flat list of nodes.
//
// The returned *Tree is a holder: its Node is nil and its Nodes are the
// top-level subtrees. Every node is attached beneath the first node (in
// depth-first, parent-before-children order) whose ID() equals the node's
// PID(); nodes for which no such parent exists become top-level subtrees.
//
// The input does not have to list parents before their children: once all
// nodes have been placed, top-level subtrees whose parent showed up later in
// the input are moved beneath that parent. nil entries in nodes are skipped.
func NewTree(nodes []TreeNode) *Tree {
	tree := &Tree{
		Node:  nil,
		Nodes: make([]*Tree, 0),
	}

	// Pass 1: attach each node under an already-known parent, or park it at
	// the top level when its parent has not been seen (yet).
	for _, node := range nodes {
		if node == nil {
			// A nil payload has no ID/PID and would make later lookups panic.
			continue
		}
		if !traverseAdd(tree, node) {
			tree.Nodes = append(tree.Nodes, newTree(node))
		}
	}

	// findParent searches the candidate subtrees, depth-first and parent
	// before children (the same order traverseAdd uses), for the position
	// whose payload ID equals pid.
	var findParent func(candidates []*Tree, pid string) *Tree
	findParent = func(candidates []*Tree, pid string) *Tree {
		for _, candidate := range candidates {
			if candidate == nil || candidate.Node == nil {
				continue
			}
			if candidate.Node.ID() == pid {
				return candidate
			}
			if parent := findParent(candidate.Nodes, pid); parent != nil {
				return parent
			}
		}
		return nil
	}

	// Pass 2: re-home parked subtrees whose parent arrived after them. The
	// search excludes the subtree being moved, so a subtree can never be
	// attached inside itself. Each move removes one top-level entry, which
	// bounds the number of iterations.
	for moved := true; moved; {
		moved = false
		for i, orphan := range tree.Nodes {
			others := make([]*Tree, 0, len(tree.Nodes)-1)
			others = append(others, tree.Nodes[:i]...)
			others = append(others, tree.Nodes[i+1:]...)

			parent := findParent(others, orphan.Node.PID())
			if parent == nil {
				continue
			}
			parent.Nodes = append(parent.Nodes, orphan)
			tree.Nodes = others
			moved = true
			break
		}
	}
	return tree
}

// newTree wraps node in a fresh subtree that has no children yet.
// Nodes is initialised to an empty, non-nil slice.
func newTree(node TreeNode) *Tree {
	leaf := new(Tree)
	leaf.Node = node
	leaf.Nodes = []*Tree{}
	return leaf
}

// Root returns the payload regarded as the root of tree: the tree's own Node
// when it has one, otherwise the Node of its first child subtree. It returns
// nil when the tree has neither a payload nor children.
func (tree *Tree) Root() TreeNode {
	switch {
	case tree.Node != nil:
		return tree.Node
	case len(tree.Nodes) > 0:
		return tree.Nodes[0].Node
	default:
		return nil
	}
}

// TreeNodePaths returns the chain of nodes leading to node, ordered from the
// top-most ancestor down to node itself (for example 1 -> 5 -> 7 for node 7).
//
// Ancestors are resolved one step at a time by searching the tree for the
// position whose ID() equals the current node's PID(), so the cost is one
// tree search per ancestor. The walk stops when no parent is found, or when
// the next parent's ID has already been visited, which guards against
// self-referencing nodes (PID() == ID()) and longer ID cycles that would
// otherwise loop forever.
//
// It returns an empty, non-nil slice when node is nil.
func (tree *Tree) TreeNodePaths(node TreeNode) []TreeNode {
	path := make([]TreeNode, 0)
	if node == nil {
		return path
	}

	isParentOf := func(candidate, child TreeNode) bool {
		return candidate.ID() == child.PID()
	}

	visited := make(map[string]struct{})
	for cur := node; cur != nil; {
		path = append(path, cur)
		visited[cur.ID()] = struct{}{}

		parent := tree.find(cur, isParentOf)
		if parent == nil {
			break
		}
		if _, seen := visited[parent.Node.ID()]; seen {
			// Following this parent would revisit a node already on the path.
			break
		}
		cur = parent.Node
	}

	// The walk collected node-first; flip it so the top-most ancestor leads.
	for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 {
		path[i], path[j] = path[j], path[i]
	}
	return path
}

// Add attaches node beneath the first position in tree whose payload ID()
// equals node.PID() and reports whether such a parent was found.
//
// The receiver's own payload is considered first, so calling Add on a subtree
// with one of its direct children succeeds; after that the descendants are
// searched depth-first. Add returns false, without modifying the tree, when
// the receiver or node is nil or when no parent matches.
func (tree *Tree) Add(node TreeNode) bool {
	if tree == nil || node == nil {
		return false
	}
	// traverseAdd only inspects descendants, so check the receiver itself here.
	if tree.Node != nil && tree.Node.ID() == node.PID() {
		tree.Nodes = append(tree.Nodes, &Tree{
			Node:  node,
			Nodes: make([]*Tree, 0),
		})
		return true
	}
	return traverseAdd(tree, node)
}

// AllChildNode looks node up in the tree by ID and returns the payloads of
// the whole subtree rooted there, the matched node included. The result is an
// empty slice when node is not present in the tree.
func (tree *Tree) AllChildNode(node TreeNode) []TreeNode {
	if start := tree.find(node, nil); start != nil {
		return tree.allChildNode(start, nil)
	}
	return []TreeNode{}
}

// AllLeafNode returns the payloads of all leaves (positions without children)
// in the subtree rooted at node. When node is nil the whole tree is scanned.
//
// It returns an empty, non-nil slice when node is not present in the tree or
// when there are no leaves. Positions without a payload, such as the holder
// of an empty tree, are never reported, so the result contains no nil entries.
func (tree *Tree) AllLeafNode(node TreeNode) []TreeNode {
	start := tree
	if node != nil {
		start = tree.find(node, nil)
	}
	if start == nil {
		return []TreeNode{}
	}

	leaves := tree.allChildNode(start, func(candidate *Tree) bool {
		// The payload check keeps an empty holder from being reported as a
		// leaf whose payload is nil.
		return candidate.Node != nil && len(candidate.Nodes) == 0
	})
	if len(leaves) == 0 {
		return []TreeNode{}
	}
	return leaves
}

// Depth returns the payloads found at the given depth, counted from the
// tree's root as reported by Root(): depth 1 is the root itself, depth 2 its
// direct children, and so on. Only the subtree of that root is inspected.
// The result is an empty slice when the tree has no root.
func (tree *Tree) Depth(depth int) []TreeNode {
	if start := tree.find(tree.Root(), nil); start != nil {
		return tree.allChildByDepth(start, depth)
	}
	return []TreeNode{}
}

// AllChildNodeByDepth looks node up in the tree by ID and returns the
// payloads located at the given depth below it, where depth 1 is the matched
// node itself and depth 2 its direct children. The result is an empty slice
// when node is not present in the tree.
func (tree *Tree) AllChildNodeByDepth(node TreeNode, depth int) []TreeNode {
	if start := tree.find(node, nil); start != nil {
		return tree.allChildByDepth(start, depth)
	}
	return []TreeNode{}
}

// Find searches the tree for node and returns the matching subtree, or nil
// when nothing matches. It is the exported entry point for find.
//
// When compared is nil, a position matches if its payload ID() equals
// node.ID(). Otherwise compared is called with the position's payload as a
// and node as b, and the first position for which it returns true is used.
func (tree *Tree) Find(node TreeNode, compared func(a, b TreeNode) bool) *Tree {
	return tree.find(node, compared)
}

// find walks the tree breadth-first, starting at the receiver, and returns
// the first position that matches node, or nil when there is none.
//
// When compared is nil, a position matches if its payload ID() equals
// node.ID(). Otherwise compared is called with the position's payload as a
// and node as b. Positions without a payload never match, but their children
// are still visited. nil subtrees (including a nil receiver) are skipped, and
// a nil node combined with a nil compared yields nil because there is no ID
// to match against.
func (tree *Tree) find(node TreeNode, compared func(a, b TreeNode) bool) *Tree {
	if compared == nil && node == nil {
		return nil
	}

	queue := []*Tree{tree}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]
		// Guard before touching cur.Nodes so a nil entry cannot panic.
		if cur == nil {
			continue
		}
		queue = append(queue, cur.Nodes...)
		if cur.Node == nil {
			continue
		}
		if compared != nil {
			if compared(cur.Node, node) {
				return cur
			}
			continue
		}
		if cur.Node.ID() == node.ID() {
			return cur
		}
	}
	return nil
}

// allChildNode returns the payloads of the subtree rooted at treeNode,
// treeNode itself included, in breadth-first order.
//
// When filter is non-nil only positions for which it returns true are
// reported; children of rejected positions are still visited. nil subtrees
// and positions without a payload (such as the holder returned by NewTree)
// are skipped, so the result never contains nil entries. The result is nil
// when nothing qualifies.
func (tree *Tree) allChildNode(treeNode *Tree, filter func(node *Tree) bool) []TreeNode {
	var res []TreeNode
	queue := []*Tree{treeNode}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]
		if cur == nil {
			continue
		}
		queue = append(queue, cur.Nodes...)
		if cur.Node == nil {
			// Nothing to report for a payload-less position.
			continue
		}
		if filter != nil && !filter(cur) {
			continue
		}
		res = append(res, cur.Node)
	}
	return res
}

// traverseAdd recursively looks for the parent of node among the descendants
// of tree and attaches node beneath it.
//
// tree is the subtree currently being searched; its own payload is not
// examined, only its children and their descendants. node is the node to
// place. The search is depth-first: each child is tested first, then its
// descendants, before moving on to the next sibling. The first position whose
// payload ID() equals node.PID() receives node as a new child.
//
// It reports whether node was attached. A nil tree or nil node yields false;
// nil children and children without a payload are never matched.
func traverseAdd(tree *Tree, node TreeNode) bool {
	if tree == nil || node == nil {
		return false
	}
	pid := node.PID()
	for _, child := range tree.Nodes {
		if child == nil {
			continue
		}
		if child.Node != nil && child.Node.ID() == pid {
			child.Nodes = append(child.Nodes, newTree(node))
			return true
		}
		if traverseAdd(child, node) {
			return true
		}
	}
	return false
}

// allChildByDepth returns the payloads located at the given depth of the
// subtree rooted at treeNode. depth is 1-based: depth 1 is treeNode's Root(),
// depth 2 its direct children, depth 3 their children, and so on.
// The result is nil when depth is not positive or lies below the deepest level.
func (tree *Tree) allChildByDepth(treeNode *Tree, depth int) []TreeNode {
	var found []TreeNode
	if depth <= 0 {
		return found
	}
	if root := treeNode.Root(); root != nil && depth == 1 {
		return []TreeNode{root}
	}

	// level holds every position of the depth just processed; each round
	// gathers their children, which make up the next depth.
	level := []*Tree{treeNode}
	for reached := 2; len(level) > 0; reached++ {
		var below []*Tree
		for _, position := range level {
			below = append(below, position.Nodes...)
		}
		if reached == depth {
			for _, position := range below {
				found = append(found, position.Node)
			}
			return found
		}
		level = below
	}
	return found
}