package domain

import (
	"fmt"
	"strconv"
)

// StructTree is the string label "tree", presumably selecting a nested
// (hierarchical) structure. NOTE(review): its consumers live outside this
// file; confirm the intended use there.
const StructTree = "tree"

// StructList is the string label "list", presumably selecting a flat
// structure. NOTE(review): confirm against the callers that read it.
const StructList = "list"

// TreeNode is an element that can be placed in a Tree. ID identifies the
// element and PID holds the ID of its parent; NewTrees links a node beneath
// the node whose ID equals its PID.
type TreeNode interface {
	PID() string
	ID() string
}

// Tree is one node of a tree together with its direct children.
// The synthetic root returned by NewTrees has a nil Node.
type Tree struct {
	Node  TreeNode `json:"node"`
	Nodes []*Tree  `json:"nodes"`
}

// NewTrees builds a forest from a flat list of nodes and returns it under a
// synthetic root whose Node is nil.
//
// Each node is attached beneath the node whose ID equals its PID, no matter
// whether the parent appears before or after the child in nodes. Nodes whose
// parent is absent become direct children of the root. Within every parent
// (and at the root) children keep their relative order from nodes.
//
// Edge cases:
//   - nil entries are ignored;
//   - if several nodes share an ID, children attach to the first of them;
//   - a node whose link would close a cycle (including PID == ID) is placed
//     at the root instead, so every node stays reachable.
func NewTrees(nodes []TreeNode) *Tree {
	root := &Tree{Node: nil, Nodes: make([]*Tree, 0)}

	// Wrap and index every node up front so that a parent appearing later in
	// the input can still adopt children that appeared earlier.
	wrapped := make([]*Tree, len(nodes))
	byID := make(map[string]*Tree, len(nodes))
	for i, n := range nodes {
		if n == nil {
			continue
		}
		wrapped[i] = NewTree(n)
		if _, seen := byID[n.ID()]; !seen {
			byID[n.ID()] = wrapped[i]
		}
	}

	// parentOf records the links made so far; it never contains a cycle, so
	// walking it upwards always terminates.
	parentOf := make(map[*Tree]*Tree, len(nodes))
	for i, n := range nodes {
		child := wrapped[i]
		if child == nil {
			continue
		}
		parent, ok := byID[n.PID()]
		if !ok || linkWouldCycle(parentOf, parent, child) {
			root.Nodes = append(root.Nodes, child)
			continue
		}
		parent.Nodes = append(parent.Nodes, child)
		parentOf[child] = parent
	}
	return root
}

// linkWouldCycle reports whether placing child beneath parent would make
// child its own ancestor, given the links already recorded in parentOf.
func linkWouldCycle(parentOf map[*Tree]*Tree, parent, child *Tree) bool {
	for p := parent; p != nil; p = parentOf[p] {
		if p == child {
			return true
		}
	}
	return false
}

// NewTree wraps node in a Tree with no children. Nodes starts as an empty,
// non-nil slice, which encoding/json emits as [] (a nil slice would be null).
func NewTree(node TreeNode) *Tree {
	return &Tree{
		Node:  node,
		Nodes: make([]*Tree, 0),
	}
}

// AllChildNodes locates node in tree (matched by ID) and returns it together
// with all of its descendants. When node is not present in tree, it returns
// an empty, non-nil slice.
func (tree *Tree) AllChildNodes(node TreeNode) []TreeNode {
	if sub := tree.find(node); sub != nil {
		return tree.allChild(sub)
	}
	return []TreeNode{}
}

// find searches tree breadth-first and returns the first subtree whose Node
// has the same ID as node. It returns nil when there is no match or when node
// is nil. nil subtrees, and subtrees with a nil Node (such as a root built
// with Node: nil), are never matched; their children are still searched.
func (tree *Tree) find(node TreeNode) *Tree {
	if node == nil {
		return nil
	}
	target := node.ID()
	queue := []*Tree{tree}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]
		// Test for nil before reading cur.Nodes, so a nil entry is skipped
		// instead of causing a nil-pointer dereference.
		if cur == nil {
			continue
		}
		queue = append(queue, cur.Nodes...)
		if cur.Node != nil && cur.Node.ID() == target {
			return cur
		}
	}
	return nil
}

// allChild returns treeNode's own Node followed by the Node of every subtree
// beneath it, in breadth-first order. The receiver is not consulted. A nil
// Node is included in the result as nil.
func (tree *Tree) allChild(treeNode *Tree) []TreeNode {
	var res []TreeNode
	queue := []*Tree{treeNode}
	for head := 0; head < len(queue); head++ {
		cur := queue[head]
		queue = append(queue, cur.Nodes...)
		res = append(res, cur.Node)
	}
	return res
}

// traverse tries to attach node beneath an existing subtree of tree.
//
// It walks tree's children depth-first. Each child is compared before its own
// subtree is searched, and an earlier sibling's whole subtree is searched
// before the next sibling. The first child whose Node ID equals node.PID()
// receives NewTree(node) as its last child. tree's own Node is never
// compared, and every visited child must have a non-nil Node.
//
// It reports whether node was attached.
func traverse(tree *Tree, node TreeNode) bool {
	for _, child := range tree.Nodes {
		if child.Node.ID() == node.PID() {
			child.Nodes = append(child.Nodes, NewTree(node))
			return true
		}
		if traverse(child, node) {
			return true
		}
	}
	return false
}

// AllSubDepartment returns the node in tree whose ID matches node, followed
// by its descendants in breadth-first order. A descendant that is an *Org with
// IsOrg == IsOrgFlag (an organization rather than a department) is left out,
// together with everything beneath it. The starting node itself is always
// included, whatever its concrete type. It returns an empty, non-nil slice
// when node is not found.
func (tree *Tree) AllSubDepartment(node TreeNode) []TreeNode {
	start := tree.find(node)
	if start == nil {
		return []TreeNode{}
	}
	var res []TreeNode
	queue := []*Tree{start}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]
		// Recognise the starting node by pointer rather than by its OrgId, so
		// that the start may be any TreeNode, not only an *Org.
		if cur != start {
			if org, ok := cur.Node.(*Org); ok && org.IsOrg == IsOrgFlag {
				continue
			}
		}
		queue = append(queue, cur.Nodes...)
		res = append(res, cur.Node)
	}
	return res
}

// Int64String is an int64 that always encodes to JSON as a quoted decimal
// string (1 -> "1"). When decoding, it accepts either a bare JSON number (1)
// or a quoted decimal string ("1"), so it can read its own output. JSON null
// leaves the value unchanged.
type Int64String int64

// MarshalJSON encodes t as a quoted decimal string.
func (t Int64String) MarshalJSON() ([]byte, error) {
	return quoteInt64(int64(t)), nil
}

// UnmarshalJSON decodes a bare or quoted base-10 integer into t. On a parse
// error, t is set to 0 and the strconv error is returned.
func (t *Int64String) UnmarshalJSON(data []byte) error {
	v, isNull, err := parseJSONInt64(data)
	if isNull {
		return nil
	}
	if err != nil {
		*t = 0
		return err
	}
	*t = Int64String(v)
	return nil
}

// StringInt64 is an int64 that always encodes to JSON as a quoted decimal
// string (1 -> "1"). When decoding, it accepts a quoted decimal string
// ("1" -> 1), and also tolerates a bare JSON number (1). JSON null leaves the
// value unchanged.
type StringInt64 int64

// MarshalJSON encodes t as a quoted decimal string.
func (t StringInt64) MarshalJSON() ([]byte, error) {
	return quoteInt64(int64(t)), nil
}

// UnmarshalJSON decodes a quoted or bare base-10 integer into t. On a parse
// error, t is set to 0 and an error wrapping the strconv error is returned.
func (t *StringInt64) UnmarshalJSON(data []byte) error {
	v, isNull, err := parseJSONInt64(data)
	if isNull {
		return nil
	}
	if err != nil {
		*t = 0
		return fmt.Errorf("字符数字格式有误:%s: %w", data, err)
	}
	*t = StringInt64(v)
	return nil
}

// quoteInt64 renders v as a JSON string literal containing its base-10 form.
func quoteInt64(v int64) []byte {
	// Capacity covers the two quotes, an optional sign, and 19 digits.
	b := make([]byte, 0, 22)
	b = append(b, '"')
	b = strconv.AppendInt(b, v, 10)
	return append(b, '"')
}

// parseJSONInt64 parses a raw JSON token holding a base-10 int64. The token
// may be a bare number (12) or a string with surrounding double quotes
// ("12"). For the literal null it returns isNull == true and no error, so
// callers can treat null as a no-op, as encoding/json recommends.
func parseJSONInt64(data []byte) (v int64, isNull bool, err error) {
	s := string(data)
	if s == "null" {
		return 0, true, nil
	}
	// Strip the quotes only when both are actually present; a bare number
	// must not lose its first and last digits.
	if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
		s = s[1 : len(s)-1]
	}
	v, err = strconv.ParseInt(s, 10, 64)
	return v, false, err
}