download_test.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. package s3manager_test
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io/ioutil"
  6. "net/http"
  7. "regexp"
  8. "strconv"
  9. "sync"
  10. "testing"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/aws/request"
  14. "github.com/aws/aws-sdk-go/awstesting/unit"
  15. "github.com/aws/aws-sdk-go/service/s3"
  16. "github.com/aws/aws-sdk-go/service/s3/s3manager"
  17. )
  18. func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
  19. var m sync.Mutex
  20. names := []string{}
  21. ranges := []string{}
  22. svc := s3.New(unit.Session)
  23. svc.Handlers.Send.Clear()
  24. svc.Handlers.Send.PushBack(func(r *request.Request) {
  25. m.Lock()
  26. defer m.Unlock()
  27. names = append(names, r.Operation.Name)
  28. ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
  29. rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
  30. rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
  31. start, _ := strconv.ParseInt(rng[1], 10, 64)
  32. fin, _ := strconv.ParseInt(rng[2], 10, 64)
  33. fin++
  34. if fin > int64(len(data)) {
  35. fin = int64(len(data))
  36. }
  37. bodyBytes := data[start:fin]
  38. r.HTTPResponse = &http.Response{
  39. StatusCode: 200,
  40. Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
  41. Header: http.Header{},
  42. }
  43. r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
  44. start, fin-1, len(data)))
  45. r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
  46. })
  47. return svc, &names, &ranges
  48. }
  49. func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
  50. var m sync.Mutex
  51. names := []string{}
  52. svc := s3.New(unit.Session)
  53. svc.Handlers.Send.Clear()
  54. svc.Handlers.Send.PushBack(func(r *request.Request) {
  55. m.Lock()
  56. defer m.Unlock()
  57. names = append(names, r.Operation.Name)
  58. r.HTTPResponse = &http.Response{
  59. StatusCode: 200,
  60. Body: ioutil.NopCloser(bytes.NewReader(data[:])),
  61. Header: http.Header{},
  62. }
  63. r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(data)))
  64. })
  65. return svc, &names
  66. }
  67. func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
  68. var m sync.Mutex
  69. names := []string{}
  70. var index int = 0
  71. svc := s3.New(unit.Session)
  72. svc.Handlers.Send.Clear()
  73. svc.Handlers.Send.PushBack(func(r *request.Request) {
  74. m.Lock()
  75. defer m.Unlock()
  76. names = append(names, r.Operation.Name)
  77. r.HTTPResponse = &http.Response{
  78. StatusCode: states[index],
  79. Body: ioutil.NopCloser(bytes.NewReader(data[:])),
  80. Header: http.Header{},
  81. }
  82. index++
  83. })
  84. return svc, &names
  85. }
  86. func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
  87. var m sync.Mutex
  88. names := []string{}
  89. ranges := []string{}
  90. var index int = 0
  91. svc := s3.New(unit.Session)
  92. svc.Handlers.Send.Clear()
  93. svc.Handlers.Send.PushBack(func(r *request.Request) {
  94. m.Lock()
  95. defer m.Unlock()
  96. names = append(names, r.Operation.Name)
  97. ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
  98. rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
  99. rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
  100. start, _ := strconv.ParseInt(rng[1], 10, 64)
  101. fin, _ := strconv.ParseInt(rng[2], 10, 64)
  102. fin++
  103. if fin >= int64(len(data)) {
  104. fin = int64(len(data))
  105. }
  106. // Setting start and finish to 0 because this state of 1 is suppose to
  107. // be an error state of 416
  108. if index == len(states)-1 {
  109. start = 0
  110. fin = 0
  111. }
  112. bodyBytes := data[start:fin]
  113. r.HTTPResponse = &http.Response{
  114. StatusCode: states[index],
  115. Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
  116. Header: http.Header{},
  117. }
  118. r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
  119. start, fin-1))
  120. index++
  121. })
  122. return svc, &names
  123. }
  124. func TestDownloadOrder(t *testing.T) {
  125. s, names, ranges := dlLoggingSvc(buf12MB)
  126. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  127. d.Concurrency = 1
  128. })
  129. w := &aws.WriteAtBuffer{}
  130. n, err := d.Download(w, &s3.GetObjectInput{
  131. Bucket: aws.String("bucket"),
  132. Key: aws.String("key"),
  133. })
  134. assert.Nil(t, err)
  135. assert.Equal(t, int64(len(buf12MB)), n)
  136. assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
  137. assert.Equal(t, []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}, *ranges)
  138. count := 0
  139. for _, b := range w.Bytes() {
  140. count += int(b)
  141. }
  142. assert.Equal(t, 0, count)
  143. }
  144. func TestDownloadZero(t *testing.T) {
  145. s, names, ranges := dlLoggingSvc([]byte{})
  146. d := s3manager.NewDownloaderWithClient(s)
  147. w := &aws.WriteAtBuffer{}
  148. n, err := d.Download(w, &s3.GetObjectInput{
  149. Bucket: aws.String("bucket"),
  150. Key: aws.String("key"),
  151. })
  152. assert.Nil(t, err)
  153. assert.Equal(t, int64(0), n)
  154. assert.Equal(t, []string{"GetObject"}, *names)
  155. assert.Equal(t, []string{"bytes=0-5242879"}, *ranges)
  156. }
  157. func TestDownloadSetPartSize(t *testing.T) {
  158. s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
  159. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  160. d.Concurrency = 1
  161. d.PartSize = 1
  162. })
  163. w := &aws.WriteAtBuffer{}
  164. n, err := d.Download(w, &s3.GetObjectInput{
  165. Bucket: aws.String("bucket"),
  166. Key: aws.String("key"),
  167. })
  168. assert.Nil(t, err)
  169. assert.Equal(t, int64(3), n)
  170. assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
  171. assert.Equal(t, []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}, *ranges)
  172. assert.Equal(t, []byte{1, 2, 3}, w.Bytes())
  173. }
  174. func TestDownloadError(t *testing.T) {
  175. s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
  176. num := 0
  177. s.Handlers.Send.PushBack(func(r *request.Request) {
  178. num++
  179. if num > 1 {
  180. r.HTTPResponse.StatusCode = 400
  181. r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
  182. }
  183. })
  184. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  185. d.Concurrency = 1
  186. d.PartSize = 1
  187. })
  188. w := &aws.WriteAtBuffer{}
  189. n, err := d.Download(w, &s3.GetObjectInput{
  190. Bucket: aws.String("bucket"),
  191. Key: aws.String("key"),
  192. })
  193. assert.NotNil(t, err)
  194. assert.Equal(t, int64(1), n)
  195. assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
  196. assert.Equal(t, []byte{1}, w.Bytes())
  197. }
  198. func TestDownloadNonChunk(t *testing.T) {
  199. s, names := dlLoggingSvcNoChunk(buf2MB)
  200. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  201. d.Concurrency = 1
  202. })
  203. w := &aws.WriteAtBuffer{}
  204. n, err := d.Download(w, &s3.GetObjectInput{
  205. Bucket: aws.String("bucket"),
  206. Key: aws.String("key"),
  207. })
  208. assert.Nil(t, err)
  209. assert.Equal(t, int64(len(buf2MB)), n)
  210. assert.Equal(t, []string{"GetObject"}, *names)
  211. count := 0
  212. for _, b := range w.Bytes() {
  213. count += int(b)
  214. }
  215. assert.Equal(t, 0, count)
  216. }
  217. func TestDownloadNoContentRangeLength(t *testing.T) {
  218. s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})
  219. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  220. d.Concurrency = 1
  221. })
  222. w := &aws.WriteAtBuffer{}
  223. n, err := d.Download(w, &s3.GetObjectInput{
  224. Bucket: aws.String("bucket"),
  225. Key: aws.String("key"),
  226. })
  227. assert.Nil(t, err)
  228. assert.Equal(t, int64(len(buf2MB)), n)
  229. assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
  230. count := 0
  231. for _, b := range w.Bytes() {
  232. count += int(b)
  233. }
  234. assert.Equal(t, 0, count)
  235. }
  236. func TestDownloadContentRangeTotalAny(t *testing.T) {
  237. s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})
  238. d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
  239. d.Concurrency = 1
  240. })
  241. w := &aws.WriteAtBuffer{}
  242. n, err := d.Download(w, &s3.GetObjectInput{
  243. Bucket: aws.String("bucket"),
  244. Key: aws.String("key"),
  245. })
  246. assert.Nil(t, err)
  247. assert.Equal(t, int64(len(buf2MB)), n)
  248. assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
  249. count := 0
  250. for _, b := range w.Bytes() {
  251. count += int(b)
  252. }
  253. assert.Equal(t, 0, count)
  254. }