// download.go
  1. package s3manager
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "github.com/aws/aws-sdk-go/aws/awserr"
  10. "github.com/aws/aws-sdk-go/aws/awsutil"
  11. "github.com/aws/aws-sdk-go/aws/client"
  12. "github.com/aws/aws-sdk-go/aws/request"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. "github.com/aws/aws-sdk-go/service/s3/s3iface"
  15. )
  16. // DefaultDownloadPartSize is the default range of bytes to get at a time when
  17. // using Download().
  18. const DefaultDownloadPartSize = 1024 * 1024 * 5
  19. // DefaultDownloadConcurrency is the default number of goroutines to spin up
  20. // when using Download().
  21. const DefaultDownloadConcurrency = 5
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to be done concurrently.
type Downloader struct {
	// The size (in bytes) of the byte range requested from S3 for each chunk
	// of the object being downloaded. If this value is set to zero, the
	// DefaultDownloadPartSize value will be used.
	//
	// NOTE(review): the previous comment here described uploading parts and a
	// 5MB minimum; in this file PartSize is only used as the Range size per
	// GET, and init() only replaces a zero value — no minimum is enforced.
	PartSize int64

	// The number of goroutines to spin up in parallel when downloading parts.
	// If this is set to zero, the DefaultDownloadConcurrency value will be used.
	Concurrency int

	// An S3 client to use when performing downloads.
	S3 s3iface.S3API
}
  36. // NewDownloader creates a new Downloader instance to downloads objects from
  37. // S3 in concurrent chunks. Pass in additional functional options to customize
  38. // the downloader behavior. Requires a client.ConfigProvider in order to create
  39. // a S3 service client. The session.Session satisfies the client.ConfigProvider
  40. // interface.
  41. //
  42. // Example:
  43. // // The session the S3 Downloader will use
  44. // sess, err := session.NewSession()
  45. //
  46. // // Create a downloader with the session and default options
  47. // downloader := s3manager.NewDownloader(sess)
  48. //
  49. // // Create a downloader with the session and custom options
  50. // downloader := s3manager.NewDownloader(sess, func(d *s3manager.Uploader) {
  51. // d.PartSize = 64 * 1024 * 1024 // 64MB per part
  52. // })
  53. func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
  54. d := &Downloader{
  55. S3: s3.New(c),
  56. PartSize: DefaultDownloadPartSize,
  57. Concurrency: DefaultDownloadConcurrency,
  58. }
  59. for _, option := range options {
  60. option(d)
  61. }
  62. return d
  63. }
  64. // NewDownloaderWithClient creates a new Downloader instance to downloads
  65. // objects from S3 in concurrent chunks. Pass in additional functional
  66. // options to customize the downloader behavior. Requires a S3 service client
  67. // to make S3 API calls.
  68. //
  69. // Example:
  70. // // The session the S3 Downloader will use
  71. // sess, err := session.NewSession()
  72. //
  73. // // The S3 client the S3 Downloader will use
  74. // s3Svc := s3.new(sess)
  75. //
  76. // // Create a downloader with the s3 client and default options
  77. // downloader := s3manager.NewDownloaderWithClient(s3Svc)
  78. //
  79. // // Create a downloader with the s3 client and custom options
  80. // downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Uploader) {
  81. // d.PartSize = 64 * 1024 * 1024 // 64MB per part
  82. // })
  83. func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
  84. d := &Downloader{
  85. S3: svc,
  86. PartSize: DefaultDownloadPartSize,
  87. Concurrency: DefaultDownloadConcurrency,
  88. }
  89. for _, option := range options {
  90. option(d)
  91. }
  92. return d
  93. }
  94. // Download downloads an object in S3 and writes the payload into w using
  95. // concurrent GET requests.
  96. //
  97. // Additional functional options can be provided to configure the individual
  98. // upload. These options are copies of the Uploader instance Upload is called from.
  99. // Modifying the options will not impact the original Uploader instance.
  100. //
  101. // It is safe to call this method concurrently across goroutines.
  102. //
  103. // The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
  104. // downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
  105. func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
  106. impl := downloader{w: w, in: input, ctx: d}
  107. for _, option := range options {
  108. option(&impl.ctx)
  109. }
  110. return impl.download()
  111. }
// downloader is the implementation structure used internally by Downloader.
// One instance is created per Download() call and carries the per-call state
// shared between the worker goroutines.
type downloader struct {
	ctx Downloader // per-call copy of the Downloader options (see Download)

	in *s3.GetObjectInput // request template; each chunk copies it and sets Range
	w  io.WriterAt        // destination all chunks are written into

	wg sync.WaitGroup // tracks outstanding downloadPart workers
	m  sync.Mutex     // guards totalBytes, written, and err

	// pos is the next byte offset to queue. It is only touched by the single
	// goroutine running download()/getChunk(), so it is not guarded by m.
	pos        int64
	totalBytes int64 // object size; -1 until resolved by setTotalBytes
	written    int64 // total bytes copied into w so far
	err        error // error recorded via setErr; workers stop once non-nil
}
  124. // init initializes the downloader with default options.
  125. func (d *downloader) init() {
  126. d.totalBytes = -1
  127. if d.ctx.Concurrency == 0 {
  128. d.ctx.Concurrency = DefaultDownloadConcurrency
  129. }
  130. if d.ctx.PartSize == 0 {
  131. d.ctx.PartSize = DefaultDownloadPartSize
  132. }
  133. }
// download performs the implementation of the object download across ranged
// GETs. It returns the number of bytes written to w and the first error that
// stopped the download, if any.
func (d *downloader) download() (n int64, err error) {
	d.init()

	// Spin off first worker to check additional header information; this
	// first GET also resolves totalBytes via setTotalBytes.
	d.getChunk()

	if total := d.getTotalBytes(); total >= 0 {
		// Total size is known: spin up workers and fan chunks out to them.
		ch := make(chan dlchunk, d.ctx.Concurrency)

		for i := 0; i < d.ctx.Concurrency; i++ {
			d.wg.Add(1)
			go d.downloadPart(ch)
		}

		// Assign work until all byte ranges are queued or an error occurs.
		for d.getErr() == nil {
			if d.pos >= total {
				break // We're finished queuing chunks
			}

			// Queue the next range of bytes to read.
			ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
			d.pos += d.ctx.PartSize
		}

		// Wait for completion
		close(ch)
		d.wg.Wait()
	} else {
		// Total size unknown: fall back to sequential ranged GETs on this
		// goroutine. Reading d.err directly (instead of getErr) is safe here
		// because no worker goroutines are running in this branch.
		for d.err == nil {
			d.getChunk()
		}

		// We expect a 416 error letting us know we are done downloading the
		// total bytes. Since we do not know the content's length, this will
		// keep grabbing chunks of data until the range of bytes specified in
		// the request is out of range of the content. Once, this happens, a
		// 416 should occur.
		e, ok := d.err.(awserr.RequestFailure)
		if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
			d.err = nil
		}
	}

	// Return error
	return d.written, d.err
}
  177. // downloadPart is an individual goroutine worker reading from the ch channel
  178. // and performing a GetObject request on the data with a given byte range.
  179. //
  180. // If this is the first worker, this operation also resolves the total number
  181. // of bytes to be read so that the worker manager knows when it is finished.
  182. func (d *downloader) downloadPart(ch chan dlchunk) {
  183. defer d.wg.Done()
  184. for {
  185. chunk, ok := <-ch
  186. if !ok {
  187. break
  188. }
  189. d.downloadChunk(chunk)
  190. }
  191. }
  192. // getChunk grabs a chunk of data from the body.
  193. // Not thread safe. Should only used when grabbing data on a single thread.
  194. func (d *downloader) getChunk() {
  195. chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
  196. d.pos += d.ctx.PartSize
  197. d.downloadChunk(chunk)
  198. }
// downloadChunk downloads the chunk from S3 with a ranged GetObject request
// and writes it into chunk.w at the chunk's offset. Failures are recorded via
// setErr so other workers stop picking up new work.
func (d *downloader) downloadChunk(chunk dlchunk) {
	// A previous chunk already failed; skip the request entirely.
	if d.getErr() != nil {
		return
	}

	// Get the next byte range of data: copy the caller's input and narrow it
	// to this chunk's byte range.
	in := &s3.GetObjectInput{}
	awsutil.Copy(in, d.in)
	rng := fmt.Sprintf("bytes=%d-%d",
		chunk.start, chunk.start+chunk.size-1)
	in.Range = &rng

	req, resp := d.ctx.S3.GetObjectRequest(in)
	// Tag the request so S3Manager usage shows up in the User-Agent.
	req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
	err := req.Send()
	if err != nil {
		d.setErr(err)
	} else {
		d.setTotalBytes(resp) // Set total if not yet set.

		// Stream the body through the dlchunk, which positions the bytes at
		// chunk.start within the destination io.WriterAt.
		n, err := io.Copy(&chunk, resp.Body)
		resp.Body.Close()

		if err != nil {
			d.setErr(err)
		}
		d.incrWritten(n)
	}
}
  225. // getTotalBytes is a thread-safe getter for retrieving the total byte status.
  226. func (d *downloader) getTotalBytes() int64 {
  227. d.m.Lock()
  228. defer d.m.Unlock()
  229. return d.totalBytes
  230. }
// setTotalBytes is a thread-safe setter for setting the total byte status.
// Will extract the object's total bytes from the Content-Range if the file
// will be chunked, or Content-Length. Content-Length is used when the response
// does not include a Content-Range. Meaning the object was not chunked. This
// occurs when the full file fits within the PartSize directive.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
	d.m.Lock()
	defer d.m.Unlock()

	// Only the first response to arrive sets the total; later calls no-op.
	if d.totalBytes >= 0 {
		return
	}

	if resp.ContentRange == nil {
		// ContentRange is nil when the full file contents is provided, and
		// is not chunked. Use ContentLength instead.
		if resp.ContentLength != nil {
			d.totalBytes = *resp.ContentLength
			return
		}
	} else {
		// Content-Range has the form "bytes start-end/total"; take the text
		// after the final '/'.
		parts := strings.Split(*resp.ContentRange, "/")

		total := int64(-1)
		var err error
		// Checking for whether or not a numbered total exists
		// If one does not exist, we will assume the total to be -1, undefined,
		// and sequentially download each chunk until hitting a 416 error
		totalStr := parts[len(parts)-1]
		if totalStr != "*" {
			total, err = strconv.ParseInt(totalStr, 10, 64)
			if err != nil {
				// Assign d.err directly rather than calling setErr: the
				// mutex is already held here and setErr would deadlock.
				d.err = err
				return
			}
		}

		d.totalBytes = total
	}
}
  267. func (d *downloader) incrWritten(n int64) {
  268. d.m.Lock()
  269. defer d.m.Unlock()
  270. d.written += n
  271. }
  272. // getErr is a thread-safe getter for the error object
  273. func (d *downloader) getErr() error {
  274. d.m.Lock()
  275. defer d.m.Unlock()
  276. return d.err
  277. }
  278. // setErr is a thread-safe setter for the error object
  279. func (d *downloader) setErr(e error) {
  280. d.m.Lock()
  281. defer d.m.Unlock()
  282. d.err = e
  283. }
  284. // dlchunk represents a single chunk of data to write by the worker routine.
  285. // This structure also implements an io.SectionReader style interface for
  286. // io.WriterAt, effectively making it an io.SectionWriter (which does not
  287. // exist).
  288. type dlchunk struct {
  289. w io.WriterAt
  290. start int64
  291. size int64
  292. cur int64
  293. }
  294. // Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
  295. // position to its end (or EOF).
  296. func (c *dlchunk) Write(p []byte) (n int, err error) {
  297. if c.cur >= c.size {
  298. return 0, io.EOF
  299. }
  300. n, err = c.w.WriteAt(p, c.start+c.cur)
  301. c.cur += int64(n)
  302. return
  303. }