cbc.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package xmlenc
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/des" // nolint: gas
	"encoding/base64"
	"errors"
	"fmt"
	"io"

	"github.com/beevik/etree"
)
  11. // CBC implements Decrypter and Encrypter for block ciphers in CBC mode
  12. type CBC struct {
  13. keySize int
  14. algorithm string
  15. cipher func([]byte) (cipher.Block, error)
  16. }
  17. // KeySize returns the length of the key required.
  18. func (e CBC) KeySize() int {
  19. return e.keySize
  20. }
  21. // Algorithm returns the name of the algorithm, as will be found
  22. // in an xenc:EncryptionMethod element.
  23. func (e CBC) Algorithm() string {
  24. return e.algorithm
  25. }
  26. // Encrypt encrypts plaintext with key, which should be a []byte of length KeySize().
  27. // It returns an xenc:EncryptedData element.
  28. func (e CBC) Encrypt(key interface{}, plaintext []byte) (*etree.Element, error) {
  29. keyBuf, ok := key.([]byte)
  30. if !ok {
  31. return nil, ErrIncorrectKeyType("[]byte")
  32. }
  33. if len(keyBuf) != e.keySize {
  34. return nil, ErrIncorrectKeyLength(e.keySize)
  35. }
  36. block, err := e.cipher(keyBuf)
  37. if err != nil {
  38. return nil, err
  39. }
  40. encryptedDataEl := etree.NewElement("xenc:EncryptedData")
  41. encryptedDataEl.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  42. {
  43. randBuf := make([]byte, 16)
  44. if _, err := RandReader.Read(randBuf); err != nil {
  45. return nil, err
  46. }
  47. encryptedDataEl.CreateAttr("Id", fmt.Sprintf("_%x", randBuf))
  48. }
  49. em := encryptedDataEl.CreateElement("xenc:EncryptionMethod")
  50. em.CreateAttr("Algorithm", e.algorithm)
  51. em.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  52. plaintext = appendPadding(plaintext, block.BlockSize())
  53. iv := make([]byte, block.BlockSize())
  54. if _, err := RandReader.Read(iv); err != nil {
  55. return nil, err
  56. }
  57. mode := cipher.NewCBCEncrypter(block, iv)
  58. ciphertext := make([]byte, len(plaintext))
  59. mode.CryptBlocks(ciphertext, plaintext)
  60. ciphertext = append(iv, ciphertext...)
  61. cd := encryptedDataEl.CreateElement("xenc:CipherData")
  62. cd.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
  63. cd.CreateElement("xenc:CipherValue").SetText(base64.StdEncoding.EncodeToString(ciphertext))
  64. return encryptedDataEl, nil
  65. }
  66. // Decrypt decrypts an encrypted element with key. If the ciphertext contains an
  67. // EncryptedKey element, then the type of `key` is determined by the registered
  68. // Decryptor for the EncryptedKey element. Otherwise, `key` must be a []byte of
  69. // length KeySize().
  70. func (e CBC) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
  71. // If the key is encrypted, decrypt it.
  72. if encryptedKeyEl := ciphertextEl.FindElement("./KeyInfo/EncryptedKey"); encryptedKeyEl != nil {
  73. var err error
  74. key, err = Decrypt(key, encryptedKeyEl)
  75. if err != nil {
  76. return nil, err
  77. }
  78. }
  79. keyBuf, ok := key.([]byte)
  80. if !ok {
  81. return nil, ErrIncorrectKeyType("[]byte")
  82. }
  83. if len(keyBuf) != e.KeySize() {
  84. return nil, ErrIncorrectKeyLength(e.KeySize())
  85. }
  86. block, err := e.cipher(keyBuf)
  87. if err != nil {
  88. return nil, err
  89. }
  90. ciphertext, err := getCiphertext(ciphertextEl)
  91. if err != nil {
  92. return nil, err
  93. }
  94. if len(ciphertext) < block.BlockSize() {
  95. return nil, errors.New("ciphertext too short")
  96. }
  97. iv := ciphertext[:aes.BlockSize]
  98. ciphertext = ciphertext[aes.BlockSize:]
  99. mode := cipher.NewCBCDecrypter(block, iv)
  100. plaintext := make([]byte, len(ciphertext))
  101. mode.CryptBlocks(plaintext, ciphertext) // decrypt in place
  102. plaintext, err = stripPadding(plaintext)
  103. if err != nil {
  104. return nil, err
  105. }
  106. return plaintext, nil
  107. }
  108. var (
  109. // AES128CBC implements AES128-CBC symetric key mode for encryption and decryption
  110. AES128CBC BlockCipher = CBC{
  111. keySize: 16,
  112. algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc",
  113. cipher: aes.NewCipher,
  114. }
  115. // AES192CBC implements AES192-CBC symetric key mode for encryption and decryption
  116. AES192CBC BlockCipher = CBC{
  117. keySize: 24,
  118. algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc",
  119. cipher: aes.NewCipher,
  120. }
  121. // AES256CBC implements AES256-CBC symetric key mode for encryption and decryption
  122. AES256CBC BlockCipher = CBC{
  123. keySize: 32,
  124. algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc",
  125. cipher: aes.NewCipher,
  126. }
  127. // TripleDES implements 3DES in CBC mode for encryption and decryption
  128. TripleDES BlockCipher = CBC{
  129. keySize: 8,
  130. algorithm: "http://www.w3.org/2001/04/xmlenc#tripledes-cbc",
  131. cipher: des.NewCipher,
  132. }
  133. )
  134. func init() {
  135. RegisterDecrypter(AES128CBC)
  136. RegisterDecrypter(AES192CBC)
  137. RegisterDecrypter(AES256CBC)
  138. RegisterDecrypter(TripleDES)
  139. }
  140. func appendPadding(buf []byte, blockSize int) []byte {
  141. paddingBytes := blockSize - (len(buf) % blockSize)
  142. padding := make([]byte, paddingBytes)
  143. padding[len(padding)-1] = byte(paddingBytes)
  144. return append(buf, padding...)
  145. }
  146. func stripPadding(buf []byte) ([]byte, error) {
  147. if len(buf) < 1 {
  148. return nil, errors.New("buffer is too short for padding")
  149. }
  150. paddingBytes := int(buf[len(buf)-1])
  151. if paddingBytes > len(buf)-1 {
  152. return nil, errors.New("buffer is too short for padding")
  153. }
  154. if paddingBytes < 1 {
  155. return nil, errors.New("padding must be at least one byte")
  156. }
  157. return buf[:len(buf)-paddingBytes], nil
  158. }