package blockchain

import (
	"crypto"
	"crypto/md5"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
)

// RsaSign signs origData with an RSA private key, using PKCS #1 v1.5 over the
// MD5 digest of origData.
//
// Despite its name, publicKey must hold a PEM-encoded RSA *private* key. The
// PKCS #8 ("PRIVATE KEY") encoding is tried first. The PKCS #1
// ("RSA PRIVATE KEY") encoding is accepted as a fallback. If both fail, the
// PKCS #8 parse error is returned.
//
// An error is returned if no PEM block is found, the key cannot be parsed, the
// parsed key is not an RSA key, or signing fails.
//
// NOTE(review): MD5 is not collision resistant and should not back new
// signatures; it is kept because existing verifiers presumably expect it.
func RsaSign(publicKey []byte, origData []byte) ([]byte, error) {
	block, _ := pem.Decode(publicKey)
	if block == nil {
		return nil, errors.New("public key error")
	}

	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		// Accept bare PKCS #1 keys too; if that also fails, surface the
		// original PKCS #8 error.
		pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
		if pkcs1Err != nil {
			return nil, err
		}
		parsed = pkcs1Key
	}

	// PKCS #8 can also carry ECDSA or Ed25519 keys; reject them with an error
	// instead of panicking on the type assertion.
	priv, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("private key is %T, not *rsa.PrivateKey", parsed)
	}

	digest := md5.Sum(origData)
	return rsa.SignPKCS1v15(rand.Reader, priv, crypto.MD5, digest[:])
}

// RsaEncrypt encrypts the MD5 digest of origData (not origData itself) with a
// PEM-encoded PKIX ("PUBLIC KEY") RSA public key, using PKCS #1 v1.5 padding.
//
// Only the 16-byte digest is encrypted, so origData cannot be recovered from
// the result: decrypting it yields md5(origData). The returned ciphertext is
// raw bytes, not encoded.
// NOTE(review): this behaves like an ad-hoc integrity scheme rather than
// encryption — confirm callers intend that.
//
// An error is returned if no PEM block is found, the key cannot be parsed, the
// parsed key is not an RSA public key, or encryption fails.
func RsaEncrypt(publicKey []byte, origData []byte) ([]byte, error) {
	block, _ := pem.Decode(publicKey)
	if block == nil {
		return nil, errors.New("public key error")
	}
	parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}

	// PKIX keys may also be ECDSA or Ed25519; reject them with an error
	// instead of panicking on the type assertion.
	pub, ok := parsed.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("public key is %T, not *rsa.PublicKey", parsed)
	}

	// Do not print origData or its digest: that leaks secrets to stdout.
	digest := md5.Sum(origData)
	return rsa.EncryptPKCS1v15(rand.Reader, pub, digest[:])
}

// RsaDecrypt decrypts base64-encoded (standard alphabet) ciphertext with a
// PEM-encoded RSA private key, using PKCS #1 v1.5 padding.
//
// The PKCS #1 ("RSA PRIVATE KEY") encoding is tried first. The unencrypted
// PKCS #8 ("PRIVATE KEY") encoding is the fallback. If both fail, the PKCS #8
// parse error is returned.
//
// An error is returned if no PEM block is found, the ciphertext is not valid
// base64, the key cannot be parsed or is not an RSA key, or decryption fails.
//
// NOTE(review): PKCS #1 v1.5 decryption can act as a padding oracle if callers
// reveal decryption failures to untrusted parties; prefer OAEP for new code.
func RsaDecrypt(privateKey []byte, ciphertext []byte) ([]byte, error) {
	block, _ := pem.Decode(privateKey)
	if block == nil {
		return nil, errors.New("private key error!")
	}
	encryptData, err := base64.StdEncoding.DecodeString(string(ciphertext))
	if err != nil {
		// Previously ignored, which fed partially decoded bytes to RSA.
		return nil, fmt.Errorf("decoding base64 ciphertext: %w", err)
	}

	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		// PKCS #1 is the bare RSA key format. PKCS #8 is a more general container
		// that can also hold non-RSA keys, so the result must be type-checked.
		parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if pkcs8Err != nil {
			return nil, pkcs8Err
		}
		rsaKey, ok := parsed.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("private key is %T, not *rsa.PrivateKey", parsed)
		}
		priv = rsaKey
	}
	return rsa.DecryptPKCS1v15(rand.Reader, priv, encryptData)
}