checksums.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package sqs
  2. import (
  3. "crypto/md5"
  4. "encoding/hex"
  5. "fmt"
  6. "strings"
  7. "github.com/aws/aws-sdk-go/aws"
  8. "github.com/aws/aws-sdk-go/aws/awserr"
  9. "github.com/aws/aws-sdk-go/aws/request"
  10. )
  // Sentinel errors reported by checksumsMatch when one side of the MD5
  // comparison is unavailable.
  var (
  	// errChecksumMissingBody is returned when there is no message body to hash.
  	errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body")

  	// errChecksumMissingMD5 is returned when there is no expected MD5 to
  	// compare the computed digest against.
  	errChecksumMissingMD5 = fmt.Errorf("cannot verify checksum. missing response MD5")
  )
  15. func setupChecksumValidation(r *request.Request) {
  16. if aws.BoolValue(r.Config.DisableComputeChecksums) {
  17. return
  18. }
  19. switch r.Operation.Name {
  20. case opSendMessage:
  21. r.Handlers.Unmarshal.PushBack(verifySendMessage)
  22. case opSendMessageBatch:
  23. r.Handlers.Unmarshal.PushBack(verifySendMessageBatch)
  24. case opReceiveMessage:
  25. r.Handlers.Unmarshal.PushBack(verifyReceiveMessage)
  26. }
  27. }
  28. func verifySendMessage(r *request.Request) {
  29. if r.DataFilled() && r.ParamsFilled() {
  30. in := r.Params.(*SendMessageInput)
  31. out := r.Data.(*SendMessageOutput)
  32. err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody)
  33. if err != nil {
  34. setChecksumError(r, err.Error())
  35. }
  36. }
  37. }
  38. func verifySendMessageBatch(r *request.Request) {
  39. if r.DataFilled() && r.ParamsFilled() {
  40. entries := map[string]*SendMessageBatchResultEntry{}
  41. ids := []string{}
  42. out := r.Data.(*SendMessageBatchOutput)
  43. for _, entry := range out.Successful {
  44. entries[*entry.Id] = entry
  45. }
  46. in := r.Params.(*SendMessageBatchInput)
  47. for _, entry := range in.Entries {
  48. if e := entries[*entry.Id]; e != nil {
  49. err := checksumsMatch(entry.MessageBody, e.MD5OfMessageBody)
  50. if err != nil {
  51. ids = append(ids, *e.MessageId)
  52. }
  53. }
  54. }
  55. if len(ids) > 0 {
  56. setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
  57. }
  58. }
  59. }
  60. func verifyReceiveMessage(r *request.Request) {
  61. if r.DataFilled() && r.ParamsFilled() {
  62. ids := []string{}
  63. out := r.Data.(*ReceiveMessageOutput)
  64. for i, msg := range out.Messages {
  65. err := checksumsMatch(msg.Body, msg.MD5OfBody)
  66. if err != nil {
  67. if msg.MessageId == nil {
  68. if r.Config.Logger != nil {
  69. r.Config.Logger.Log(fmt.Sprintf(
  70. "WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.",
  71. r.RequestID, i,
  72. ))
  73. }
  74. continue
  75. }
  76. ids = append(ids, *msg.MessageId)
  77. }
  78. }
  79. if len(ids) > 0 {
  80. setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
  81. }
  82. }
  83. }
  84. func checksumsMatch(body, expectedMD5 *string) error {
  85. if body == nil {
  86. return errChecksumMissingBody
  87. } else if expectedMD5 == nil {
  88. return errChecksumMissingMD5
  89. }
  90. msum := md5.Sum([]byte(*body))
  91. sum := hex.EncodeToString(msum[:])
  92. if sum != *expectedMD5 {
  93. return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, sum)
  94. }
  95. return nil
  96. }
  97. func setChecksumError(r *request.Request, format string, args ...interface{}) {
  98. r.Retryable = aws.Bool(true)
  99. r.Error = awserr.New("InvalidChecksum", fmt.Sprintf(format, args...), nil)
  100. }