pubkey.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package xmlenc
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"io"

	"github.com/beevik/etree"
)
  9. // RSA implements Encrypter and Decrypter using RSA public key encryption.
  10. //
  11. // Use function like OAEP(), or PKCS1v15() to get an instance of this type ready
  12. // to use.
  13. type RSA struct {
  14. BlockCipher BlockCipher
  15. DigestMethod DigestMethod // only for OAEP
  16. algorithm string
  17. keyEncrypter func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error)
  18. keyDecrypter func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error)
  19. }
  20. // Algorithm returns the name of the algorithm
  21. func (e RSA) Algorithm() string {
  22. return e.algorithm
  23. }
  24. // Encrypt implements encrypter. certificate must be a []byte containing the ASN.1 bytes
  25. // of certificate containing an RSA public key.
  26. func (e RSA) Encrypt(certificate interface{}, plaintext []byte) (*etree.Element, error) {
  27. cert, ok := certificate.(*x509.Certificate)
  28. if !ok {
  29. return nil, ErrIncorrectKeyType("*x.509 certificate")
  30. }
  31. pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
  32. if !ok {
  33. return nil, ErrIncorrectKeyType("x.509 certificate with an RSA public key")
  34. }
  35. // generate a key
  36. key := make([]byte, e.BlockCipher.KeySize())
  37. if _, err := RandReader.Read(key); err != nil {
  38. return nil, err
  39. }
  40. keyInfoEl := etree.NewElement("ds:KeyInfo")
  41. keyInfoEl.CreateAttr("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#")
  42. encryptedKey := keyInfoEl.CreateElement("xenc:EncryptedKey")
  43. {
  44. randBuf := make([]byte, 16)
  45. if _, err := RandReader.Read(randBuf); err != nil {
  46. return nil, err
  47. }
  48. encryptedKey.CreateAttr("Id", fmt.Sprintf("_%x", randBuf))
  49. }
  50. encryptedKey.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  51. encryptionMethodEl := encryptedKey.CreateElement("xenc:EncryptionMethod")
  52. encryptionMethodEl.CreateAttr("Algorithm", e.algorithm)
  53. encryptionMethodEl.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  54. if e.DigestMethod != nil {
  55. dm := encryptionMethodEl.CreateElement("ds:DigestMethod")
  56. dm.CreateAttr("Algorithm", e.DigestMethod.Algorithm())
  57. dm.CreateAttr("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#")
  58. }
  59. {
  60. innerKeyInfoEl := encryptedKey.CreateElement("ds:KeyInfo")
  61. x509data := innerKeyInfoEl.CreateElement("ds:X509Data")
  62. x509data.CreateElement("ds:X509Certificate").SetText(
  63. base64.StdEncoding.EncodeToString(cert.Raw),
  64. )
  65. }
  66. buf, err := e.keyEncrypter(e, pubKey, key)
  67. if err != nil {
  68. return nil, err
  69. }
  70. cd := encryptedKey.CreateElement("xenc:CipherData")
  71. cd.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  72. cd.CreateElement("xenc:CipherValue").SetText(base64.StdEncoding.EncodeToString(buf))
  73. encryptedDataEl, err := e.BlockCipher.Encrypt(key, plaintext)
  74. if err != nil {
  75. return nil, err
  76. }
  77. encryptedDataEl.InsertChild(encryptedDataEl.FindElement("./CipherData"), keyInfoEl)
  78. return encryptedDataEl, nil
  79. }
  80. // Decrypt implements Decryptor. `key` must be an *rsa.PrivateKey.
  81. func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
  82. rsaKey, err := validateRSAKey(key, ciphertextEl)
  83. if err != nil {
  84. return nil, err
  85. }
  86. ciphertext, err := getCiphertext(ciphertextEl)
  87. if err != nil {
  88. return nil, err
  89. }
  90. {
  91. digestMethodEl := ciphertextEl.FindElement("./EncryptionMethod/DigestMethod")
  92. if digestMethodEl == nil {
  93. e.DigestMethod = SHA1
  94. } else {
  95. hashAlgorithmStr := digestMethodEl.SelectAttrValue("Algorithm", "")
  96. digestMethod, ok := digestMethods[hashAlgorithmStr]
  97. if !ok {
  98. return nil, ErrAlgorithmNotImplemented(hashAlgorithmStr)
  99. }
  100. e.DigestMethod = digestMethod
  101. }
  102. }
  103. return e.keyDecrypter(e, rsaKey, ciphertext)
  104. }
  105. // OAEP returns a version of RSA that implements RSA in OAEP-MGF1P mode. By default
  106. // the block cipher used is AES-256 CBC and the digest method is SHA-256. You can
  107. // specify other ciphers and digest methods by assigning to BlockCipher or
  108. // DigestMethod.
  109. func OAEP() RSA {
  110. return RSA{
  111. BlockCipher: AES256CBC,
  112. DigestMethod: SHA256,
  113. algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p",
  114. keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
  115. return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
  116. },
  117. keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
  118. return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
  119. },
  120. }
  121. }
  122. // PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default
  123. // the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15
  124. // does not use a digest function.
  125. func PKCS1v15() RSA {
  126. return RSA{
  127. BlockCipher: AES256CBC,
  128. DigestMethod: nil,
  129. algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-1_5",
  130. keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
  131. return rsa.EncryptPKCS1v15(RandReader, pubKey, plaintext)
  132. },
  133. keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
  134. return rsa.DecryptPKCS1v15(RandReader, privKey, ciphertext)
  135. },
  136. }
  137. }
  138. func init() {
  139. RegisterDecrypter(OAEP())
  140. RegisterDecrypter(PKCS1v15())
  141. }