package licence

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"io"
	"math/big"
	"os"
)

const (
	privateFileName  = "private.key"
	publicFileName   = "public.pem"
	privateKeyPrefix = "BASIC AIOTLINK PRIVATE KEY "
	publicKeyPrefix  = "BASIC AIOTLINK PUBLIC KEY "
)

const (
	// rsaKeyBits is the modulus size, in bits, of keys created by GetRsaKey.
	rsaKeyBits = 4096

	// pkcs1v15Overhead is the minimum number of bytes PKCS #1 v1.5 padding
	// adds to a block: 0x00, the block type, at least 8 padding bytes, 0x00.
	pkcs1v15Overhead = 11

	// privateKeyPerm is the mode of a newly created private-key file; it
	// keeps the key readable by its owner only.
	privateKeyPerm os.FileMode = 0600

	// publicKeyPerm is the mode of a newly created public-key file (the same
	// mode os.Create uses).
	publicKeyPerm os.FileMode = 0666
)

// RSASecurity holds an RSA key pair in both PEM and parsed form. Each half
// is set independently with SetPublicKey / SetPrivateKey.
type RSASecurity struct {
	pubStr []byte          // PEM-encoded public key
	priStr []byte          // PEM-encoded private key
	pubkey *rsa.PublicKey  // parsed public key
	prikey *rsa.PrivateKey // parsed private key
}

// GetRsaKey generates a new 4096-bit RSA key pair and writes it as two PEM
// files: "<prefix>-private.key" (PKCS #1 DER, PEM type privateKeyPrefix) and
// "<prefix>-public.pem" (PKIX DER, PEM type publicKeyPrefix). A newly created
// private-key file gets mode 0600; existing files are truncated and keep
// their current permissions.
func GetRsaKey(prefix string) error {
	privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
	if err != nil {
		return err
	}

	privateBlock := &pem.Block{
		Type:  privateKeyPrefix,
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	}
	if err := writePEMFile(prefix+"-"+privateFileName, privateKeyPerm, privateBlock); err != nil {
		return err
	}

	x509PublicKey, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
	if err != nil {
		return err
	}
	publicBlock := &pem.Block{
		Type:  publicKeyPrefix,
		Bytes: x509PublicKey,
	}
	return writePEMFile(prefix+"-"+publicFileName, publicKeyPerm, publicBlock)
}

// writePEMFile PEM-encodes block into path, creating the file with perm if it
// does not exist and truncating it otherwise. Errors from Close are reported
// because they can indicate that buffered data was never written.
func writePEMFile(path string, perm os.FileMode, block *pem.Block) (err error) {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()
	return pem.Encode(f, block)
}

// SetPublicKey stores the PEM-encoded public key pubStr and parses it. If
// parsing fails, the parsed public key is cleared and the error is returned.
func (rsas *RSASecurity) SetPublicKey(pubStr []byte) (err error) {
	rsas.pubStr = pubStr
	rsas.pubkey, err = rsas.GetPublickey()
	return err
}

// SetPrivateKey stores the PEM-encoded private key priStr and parses it. If
// parsing fails, the parsed private key is cleared and the error is returned.
func (rsas *RSASecurity) SetPrivateKey(priStr []byte) (err error) {
	rsas.priStr = priStr
	rsas.prikey, err = rsas.GetPrivatekey()
	return err
}

// GetPrivatekey parses the stored PEM private key (PKCS #1, falling back to
// PKCS #8) and returns it as an *rsa.PrivateKey.
func (rsas *RSASecurity) GetPrivatekey() (*rsa.PrivateKey, error) {
	return getPriKey(rsas.priStr)
}

// GetPublickey parses the stored PEM (PKIX) public key and returns it as an
// *rsa.PublicKey.
func (rsas *RSASecurity) GetPublickey() (*rsa.PublicKey, error) {
	return getPubKey(rsas.pubStr)
}

// PubKeyENCTYPT encrypts input with the public key using PKCS #1 v1.5
// (type 2) padding, chunk by chunk. The result is decrypted by PriKeyDECRYPT.
func (rsas *RSASecurity) PubKeyENCTYPT(input []byte) ([]byte, error) {
	if rsas.pubkey == nil {
		return nil, errors.New(`Please set the public key in advance`)
	}
	var output bytes.Buffer
	if err := pubKeyIO(rsas.pubkey, bytes.NewReader(input), &output, true); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}

// PubKeyDECRYPT decrypts data produced by PriKeyENCTYPT using the public key.
// input must be a whole number of modulus-sized blocks.
func (rsas *RSASecurity) PubKeyDECRYPT(input []byte) ([]byte, error) {
	if rsas.pubkey == nil {
		return nil, errors.New(`Please set the public key in advance`)
	}
	var output bytes.Buffer
	if err := pubKeyIO(rsas.pubkey, bytes.NewReader(input), &output, false); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}

// PriKeyENCTYPT "encrypts" input with the private key using PKCS #1 v1.5
// type 1 padding, chunk by chunk. The result is decrypted by PubKeyDECRYPT.
func (rsas *RSASecurity) PriKeyENCTYPT(input []byte) ([]byte, error) {
	if rsas.prikey == nil {
		return nil, errors.New(`Please set the private key in advance`)
	}
	var output bytes.Buffer
	if err := priKeyIO(rsas.prikey, bytes.NewReader(input), &output, true); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}

// PriKeyDECRYPT decrypts data produced by PubKeyENCTYPT using the private
// key. input must be a whole number of modulus-sized blocks.
func (rsas *RSASecurity) PriKeyDECRYPT(input []byte) ([]byte, error) {
	if rsas.prikey == nil {
		return nil, errors.New(`Please set the private key in advance`)
	}
	var output bytes.Buffer
	if err := priKeyIO(rsas.prikey, bytes.NewReader(input), &output, false); err != nil {
		return nil, err
	}
	return output.Bytes(), nil
}

// Errors returned by the helpers in this file. ErrKeyPairDismatch is returned
// by pubKeyDecrypt when the recovered padding block type is neither 0 nor 1.
var (
	ErrDataToLarge     = errors.New("message too long for RSA public key size")
	ErrDataLen         = errors.New("data length error")
	ErrDataBroken      = errors.New("data broken, first byte is not zero")
	ErrKeyPairDismatch = errors.New("data is not encrypted by the private key")
	ErrDecryption      = errors.New("decryption error")
	ErrPublicKey       = errors.New("get public key error")
	ErrPrivateKey      = errors.New("get private key error")
)

// getPubKey parses a PEM-encoded PKIX RSA public key. The PEM block type is
// not checked. It returns ErrPublicKey when no PEM block is found or the key
// is not an RSA key.
func getPubKey(publickey []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(publickey)
	if block == nil {
		return nil, ErrPublicKey
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	rsaPub, ok := pub.(*rsa.PublicKey)
	if !ok {
		// A valid PKIX key of another algorithm (e.g. ECDSA).
		return nil, ErrPublicKey
	}
	return rsaPub, nil
}

// getPriKey parses a PEM-encoded RSA private key, trying PKCS #1 first and
// then PKCS #8. The PEM block type is not checked. It returns ErrPrivateKey
// when no PEM block is found or a PKCS #8 key is not an RSA key.
func getPriKey(privatekey []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(privatekey)
	if block == nil {
		return nil, ErrPrivateKey
	}
	if pri, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return pri, nil
	}
	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	rsaPri, ok := key.(*rsa.PrivateKey)
	if !ok {
		// A valid PKCS #8 key of another algorithm (e.g. ECDSA).
		return nil, ErrPrivateKey
	}
	return rsaPri, nil
}

// blockSize returns how many input bytes one RSA operation consumes: the
// modulus length in bytes for decryption, or that length minus the PKCS #1
// v1.5 padding overhead for encryption.
func blockSize(pub *rsa.PublicKey, encrypt bool) int {
	k := (pub.N.BitLen() + 7) / 8
	if encrypt {
		k -= pkcs1v15Overhead
	}
	return k
}

// pubKeyByte encrypts (isEncrytp) or decrypts in with the public key. Input
// that fits in one block is processed directly; longer input is processed
// chunk by chunk through pubKeyIO.
func pubKeyByte(pub *rsa.PublicKey, in []byte, isEncrytp bool) ([]byte, error) {
	if len(in) <= blockSize(pub, isEncrytp) {
		if isEncrytp {
			return rsa.EncryptPKCS1v15(rand.Reader, pub, in)
		}
		return pubKeyDecrypt(pub, in)
	}
	// Start from an empty buffer: seeding it with a slice would make that
	// slice's bytes part of the output.
	var out bytes.Buffer
	if err := pubKeyIO(pub, bytes.NewReader(in), &out, isEncrytp); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// priKeyByte encrypts (isEncrytp) or decrypts in with the private key. Input
// that fits in one block is processed directly; longer input is processed
// chunk by chunk through priKeyIO.
func priKeyByte(pri *rsa.PrivateKey, in []byte, isEncrytp bool) ([]byte, error) {
	if len(in) <= blockSize(&pri.PublicKey, isEncrytp) {
		if isEncrytp {
			return priKeyEncrypt(rand.Reader, pri, in)
		}
		return rsa.DecryptPKCS1v15(rand.Reader, pri, in)
	}
	var out bytes.Buffer
	if err := priKeyIO(pri, bytes.NewReader(in), &out, isEncrytp); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// cryptChunks reads in as consecutive chunks of size bytes, transforms each
// chunk with fn and writes the result to out. Only the final chunk may be
// shorter than size. It returns nil once in is exhausted.
func cryptChunks(in io.Reader, out io.Writer, size int, fn func([]byte) ([]byte, error)) error {
	if size <= 0 {
		// Key too small to hold any data; a zero-length buffer would
		// otherwise never make progress.
		return ErrDataLen
	}
	buf := make([]byte, size)
	for {
		// io.ReadFull, unlike a bare Read, never yields a short chunk in the
		// middle of the stream, so every block but the last is exactly size
		// bytes, as RSA decryption requires.
		n, err := io.ReadFull(in, buf)
		switch err {
		case nil:
		case io.EOF:
			return nil
		case io.ErrUnexpectedEOF:
			// Short final chunk: process it below, then stop.
		default:
			return err
		}
		block, err := fn(buf[:n])
		if err != nil {
			return err
		}
		if _, err := out.Write(block); err != nil {
			return err
		}
		if n < size {
			return nil
		}
	}
}

// pubKeyIO streams in through the public key, encrypting (PKCS #1 v1.5
// type 2) or decrypting (pubKeyDecrypt) block by block into out.
func pubKeyIO(pub *rsa.PublicKey, in io.Reader, out io.Writer, isEncrytp bool) error {
	return cryptChunks(in, out, blockSize(pub, isEncrytp), func(b []byte) ([]byte, error) {
		if isEncrytp {
			return rsa.EncryptPKCS1v15(rand.Reader, pub, b)
		}
		return pubKeyDecrypt(pub, b)
	})
}

// priKeyIO streams r through the private key, encrypting (priKeyEncrypt) or
// decrypting (PKCS #1 v1.5 type 2) block by block into w.
func priKeyIO(pri *rsa.PrivateKey, r io.Reader, w io.Writer, isEncrytp bool) error {
	return cryptChunks(r, w, blockSize(&pri.PublicKey, isEncrytp), func(b []byte) ([]byte, error) {
		if isEncrytp {
			return priKeyEncrypt(rand.Reader, pri, b)
		}
		return rsa.DecryptPKCS1v15(rand.Reader, pri, b)
	})
}

// pubKeyDecrypt reverses priKeyEncrypt for a single block. It raises data
// (exactly k bytes, k being the modulus length) to the public exponent and
// strips the padding: 0x00, a block type of 0 or 1, padding bytes, 0x00,
// message. It returns (nil, nil) when the message after the padding is empty.
func pubKeyDecrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
	k := (pub.N.BitLen() + 7) / 8
	if k < 2 || k != len(data) {
		return nil, ErrDataLen
	}
	m := new(big.Int).SetBytes(data)
	// A valid RSA input is strictly smaller than the modulus.
	if m.Cmp(pub.N) >= 0 {
		return nil, ErrDataToLarge
	}
	m.Exp(m, big.NewInt(int64(pub.E)), pub.N)
	d := leftPad(m.Bytes(), k)
	if d[0] != 0 {
		return nil, ErrDataBroken
	}
	if d[1] != 0 && d[1] != 1 {
		return nil, ErrKeyPairDismatch
	}
	// Locate the 0x00 separator between padding and message.
	sep := bytes.IndexByte(d[2:], 0)
	if sep < 0 {
		return nil, ErrDecryption
	}
	msg := d[2+sep+1:]
	if len(msg) == 0 {
		return nil, nil
	}
	return msg, nil
}

// priKeyEncrypt pads hashed as 0x00 0x01 0xff... 0x00 || hashed (PKCS #1
// v1.5 block type 1) and applies the private-key operation. This is a PKCS #1
// v1.5 signature over hashed with no DigestInfo prefix. The result is always
// k bytes long. It returns ErrDataLen if hashed is longer than k-11 bytes.
func priKeyEncrypt(rand io.Reader, priv *rsa.PrivateKey, hashed []byte) ([]byte, error) {
	k := (priv.N.BitLen() + 7) / 8
	if k < len(hashed)+pkcs1v15Overhead {
		return nil, ErrDataLen
	}
	// With hash == 0, SignPKCS1v15 signs hashed directly using exactly the
	// padding described above. The private-key operation is thus delegated
	// to crypto/rsa instead of the hand-rolled decrypt below.
	return rsa.SignPKCS1v15(rand, priv, 0, hashed)
}

// Copied from crypto/rsa.
var bigZero = big.NewInt(0)
var bigOne = big.NewInt(1)

// encrypt computes c = m^e mod N (raw RSA public-key operation).
// Copied from crypto/rsa.
func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int {
	e := big.NewInt(int64(pub.E))
	c.Exp(m, e, pub.N)
	return c
}

// decrypt computes m = c^d mod N (raw RSA private-key operation). When random
// is non-nil the input is blinded with a random factor first. Uses the CRT
// values when priv.Precomputed is populated. Copied from crypto/rsa.
func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
	if c.Cmp(priv.N) > 0 {
		err = ErrDecryption
		return
	}
	var ir *big.Int
	if random != nil {
		var r *big.Int
		for {
			r, err = rand.Int(random, priv.N)
			if err != nil {
				return
			}
			if r.Cmp(bigZero) == 0 {
				r = bigOne
			}
			var ok bool
			ir, ok = modInverse(r, priv.N)
			if ok {
				break
			}
		}
		bigE := big.NewInt(int64(priv.E))
		rpowe := new(big.Int).Exp(r, bigE, priv.N)
		cCopy := new(big.Int).Set(c)
		cCopy.Mul(cCopy, rpowe)
		cCopy.Mod(cCopy, priv.N)
		c = cCopy
	}
	if priv.Precomputed.Dp == nil {
		m = new(big.Int).Exp(c, priv.D, priv.N)
	} else {
		m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
		m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
		m.Sub(m, m2)
		if m.Sign() < 0 {
			m.Add(m, priv.Primes[0])
		}
		m.Mul(m, priv.Precomputed.Qinv)
		m.Mod(m, priv.Primes[0])
		m.Mul(m, priv.Primes[1])
		m.Add(m, m2)
		for i, values := range priv.Precomputed.CRTValues {
			prime := priv.Primes[2+i]
			m2.Exp(c, values.Exp, prime)
			m2.Sub(m2, m)
			m2.Mul(m2, values.Coeff)
			m2.Mod(m2, prime)
			if m2.Sign() < 0 {
				m2.Add(m2, prime)
			}
			m2.Mul(m2, values.R)
			m.Add(m, m2)
		}
	}
	if ir != nil {
		m.Mul(m, ir)
		m.Mod(m, priv.N)
	}
	return
}

// copyWithLeftPad copies src to the end of dest and zeroes the leading bytes.
// Copied from crypto/rsa.
func copyWithLeftPad(dest, src []byte) {
	numPaddingBytes := len(dest) - len(src)
	for i := 0; i < numPaddingBytes; i++ {
		dest[i] = 0
	}
	copy(dest[numPaddingBytes:], src)
}

// nonZeroRandomBytes fills s with random bytes from rand, none of them zero.
// Copied from crypto/rsa.
func nonZeroRandomBytes(s []byte, rand io.Reader) (err error) {
	_, err = io.ReadFull(rand, s)
	if err != nil {
		return
	}
	for i := 0; i < len(s); i++ {
		for s[i] == 0 {
			_, err = io.ReadFull(rand, s[i:i+1])
			if err != nil {
				return
			}
			s[i] ^= 0x42
		}
	}
	return
}

// leftPad returns a new size-byte slice holding input right-aligned; when
// input is longer than size only its first size bytes are kept.
// Copied from crypto/rsa.
func leftPad(input []byte, size int) (out []byte) {
	n := len(input)
	if n > size {
		n = size
	}
	out = make([]byte, size)
	copy(out[len(out)-n:], input)
	return
}

// modInverse returns the inverse of a modulo n and true, or (nil, false) when
// gcd(a, n) != 1. Copied from crypto/rsa.
func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
	g := new(big.Int)
	x := new(big.Int)
	y := new(big.Int)
	g.GCD(x, y, a, n)
	if g.Cmp(bigOne) != 0 {
		return
	}
	if x.Cmp(bigOne) < 0 {
		x.Add(x, n)
	}
	return x, true
}