| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- package s3manager
- import (
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "sync"
- "github.com/aws/aws-sdk-go/aws/awserr"
- "github.com/aws/aws-sdk-go/aws/awsutil"
- "github.com/aws/aws-sdk-go/aws/client"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/s3"
- "github.com/aws/aws-sdk-go/service/s3/s3iface"
- )
// Defaults applied by the downloader when the corresponding Downloader
// field is left at zero.
const (
	// DefaultDownloadPartSize is the default range of bytes to get at a time
	// when using Download(): 5 MiB.
	DefaultDownloadPartSize = 5 * 1024 * 1024

	// DefaultDownloadConcurrency is the default number of goroutines to spin
	// up when using Download().
	DefaultDownloadConcurrency = 5
)
// Downloader is the structure that calls Download(). It is safe to call
// Download() on this structure for multiple objects and across concurrent
// goroutines. Mutating the Downloader's properties is not safe to be done
// concurrently.
type Downloader struct {
	// The size (in bytes) of the byte range requested by each ranged GET
	// request when downloading an object. If this value is set to zero, the
	// DefaultDownloadPartSize value will be used.
	PartSize int64

	// The number of goroutines to spin up in parallel when downloading
	// parts. If this is set to zero, the DefaultDownloadConcurrency value
	// will be used.
	Concurrency int

	// An S3 client to use when performing downloads.
	S3 s3iface.S3API
}
- // NewDownloader creates a new Downloader instance to downloads objects from
- // S3 in concurrent chunks. Pass in additional functional options to customize
- // the downloader behavior. Requires a client.ConfigProvider in order to create
- // a S3 service client. The session.Session satisfies the client.ConfigProvider
- // interface.
- //
- // Example:
- // // The session the S3 Downloader will use
- // sess, err := session.NewSession()
- //
- // // Create a downloader with the session and default options
- // downloader := s3manager.NewDownloader(sess)
- //
- // // Create a downloader with the session and custom options
- // downloader := s3manager.NewDownloader(sess, func(d *s3manager.Uploader) {
- // d.PartSize = 64 * 1024 * 1024 // 64MB per part
- // })
- func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
- d := &Downloader{
- S3: s3.New(c),
- PartSize: DefaultDownloadPartSize,
- Concurrency: DefaultDownloadConcurrency,
- }
- for _, option := range options {
- option(d)
- }
- return d
- }
- // NewDownloaderWithClient creates a new Downloader instance to downloads
- // objects from S3 in concurrent chunks. Pass in additional functional
- // options to customize the downloader behavior. Requires a S3 service client
- // to make S3 API calls.
- //
- // Example:
- // // The session the S3 Downloader will use
- // sess, err := session.NewSession()
- //
- // // The S3 client the S3 Downloader will use
- // s3Svc := s3.new(sess)
- //
- // // Create a downloader with the s3 client and default options
- // downloader := s3manager.NewDownloaderWithClient(s3Svc)
- //
- // // Create a downloader with the s3 client and custom options
- // downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Uploader) {
- // d.PartSize = 64 * 1024 * 1024 // 64MB per part
- // })
- func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
- d := &Downloader{
- S3: svc,
- PartSize: DefaultDownloadPartSize,
- Concurrency: DefaultDownloadConcurrency,
- }
- for _, option := range options {
- option(d)
- }
- return d
- }
- // Download downloads an object in S3 and writes the payload into w using
- // concurrent GET requests.
- //
- // Additional functional options can be provided to configure the individual
- // upload. These options are copies of the Uploader instance Upload is called from.
- // Modifying the options will not impact the original Uploader instance.
- //
- // It is safe to call this method concurrently across goroutines.
- //
- // The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
- // downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
- func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
- impl := downloader{w: w, in: input, ctx: d}
- for _, option := range options {
- option(&impl.ctx)
- }
- return impl.download()
- }
// downloader is the implementation structure used internally by Downloader.
// A fresh value is built for every Download call.
type downloader struct {
	// ctx is this call's copy of the Downloader configuration with the
	// per-call options applied. Despite the name it is not a
	// context.Context.
	ctx Downloader

	// in is the caller's request; each chunk request copies it and sets a
	// Range on the copy.
	in *s3.GetObjectInput
	// w is the destination every chunk writes into at its own offset.
	w io.WriterAt

	// wg tracks the worker goroutines started by download.
	wg sync.WaitGroup
	// m guards totalBytes, written and err, which worker goroutines share.
	m sync.Mutex

	// pos is the offset of the next chunk to request. Only the goroutine
	// running download (directly or via getChunk) touches it, so m does not
	// guard it.
	pos int64
	// totalBytes is the object's size, or -1 while it is still unknown.
	totalBytes int64
	// written counts the bytes copied into w so far.
	written int64
	// err is the recorded download error; once set, further chunk
	// downloads are skipped.
	err error
}
- // init initializes the downloader with default options.
- func (d *downloader) init() {
- d.totalBytes = -1
- if d.ctx.Concurrency == 0 {
- d.ctx.Concurrency = DefaultDownloadConcurrency
- }
- if d.ctx.PartSize == 0 {
- d.ctx.PartSize = DefaultDownloadPartSize
- }
- }
- // download performs the implementation of the object download across ranged
- // GETs.
- func (d *downloader) download() (n int64, err error) {
- d.init()
- // Spin off first worker to check additional header information
- d.getChunk()
- if total := d.getTotalBytes(); total >= 0 {
- // Spin up workers
- ch := make(chan dlchunk, d.ctx.Concurrency)
- for i := 0; i < d.ctx.Concurrency; i++ {
- d.wg.Add(1)
- go d.downloadPart(ch)
- }
- // Assign work
- for d.getErr() == nil {
- if d.pos >= total {
- break // We're finished queuing chunks
- }
- // Queue the next range of bytes to read.
- ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
- d.pos += d.ctx.PartSize
- }
- // Wait for completion
- close(ch)
- d.wg.Wait()
- } else {
- // Checking if we read anything new
- for d.err == nil {
- d.getChunk()
- }
- // We expect a 416 error letting us know we are done downloading the
- // total bytes. Since we do not know the content's length, this will
- // keep grabbing chunks of data until the range of bytes specified in
- // the request is out of range of the content. Once, this happens, a
- // 416 should occur.
- e, ok := d.err.(awserr.RequestFailure)
- if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
- d.err = nil
- }
- }
- // Return error
- return d.written, d.err
- }
- // downloadPart is an individual goroutine worker reading from the ch channel
- // and performing a GetObject request on the data with a given byte range.
- //
- // If this is the first worker, this operation also resolves the total number
- // of bytes to be read so that the worker manager knows when it is finished.
- func (d *downloader) downloadPart(ch chan dlchunk) {
- defer d.wg.Done()
- for {
- chunk, ok := <-ch
- if !ok {
- break
- }
- d.downloadChunk(chunk)
- }
- }
- // getChunk grabs a chunk of data from the body.
- // Not thread safe. Should only used when grabbing data on a single thread.
- func (d *downloader) getChunk() {
- chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
- d.pos += d.ctx.PartSize
- d.downloadChunk(chunk)
- }
- // downloadChunk downloads the chunk froom s3
- func (d *downloader) downloadChunk(chunk dlchunk) {
- if d.getErr() != nil {
- return
- }
- // Get the next byte range of data
- in := &s3.GetObjectInput{}
- awsutil.Copy(in, d.in)
- rng := fmt.Sprintf("bytes=%d-%d",
- chunk.start, chunk.start+chunk.size-1)
- in.Range = &rng
- req, resp := d.ctx.S3.GetObjectRequest(in)
- req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
- err := req.Send()
- if err != nil {
- d.setErr(err)
- } else {
- d.setTotalBytes(resp) // Set total if not yet set.
- n, err := io.Copy(&chunk, resp.Body)
- resp.Body.Close()
- if err != nil {
- d.setErr(err)
- }
- d.incrWritten(n)
- }
- }
- // getTotalBytes is a thread-safe getter for retrieving the total byte status.
- func (d *downloader) getTotalBytes() int64 {
- d.m.Lock()
- defer d.m.Unlock()
- return d.totalBytes
- }
- // setTotalBytes is a thread-safe setter for setting the total byte status.
- // Will extract the object's total bytes from the Content-Range if the file
- // will be chunked, or Content-Length. Content-Length is used when the response
- // does not include a Content-Range. Meaning the object was not chunked. This
- // occurs when the full file fits within the PartSize directive.
- func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
- d.m.Lock()
- defer d.m.Unlock()
- if d.totalBytes >= 0 {
- return
- }
- if resp.ContentRange == nil {
- // ContentRange is nil when the full file contents is provied, and
- // is not chunked. Use ContentLength instead.
- if resp.ContentLength != nil {
- d.totalBytes = *resp.ContentLength
- return
- }
- } else {
- parts := strings.Split(*resp.ContentRange, "/")
- total := int64(-1)
- var err error
- // Checking for whether or not a numbered total exists
- // If one does not exist, we will assume the total to be -1, undefined,
- // and sequentially download each chunk until hitting a 416 error
- totalStr := parts[len(parts)-1]
- if totalStr != "*" {
- total, err = strconv.ParseInt(totalStr, 10, 64)
- if err != nil {
- d.err = err
- return
- }
- }
- d.totalBytes = total
- }
- }
- func (d *downloader) incrWritten(n int64) {
- d.m.Lock()
- defer d.m.Unlock()
- d.written += n
- }
- // getErr is a thread-safe getter for the error object
- func (d *downloader) getErr() error {
- d.m.Lock()
- defer d.m.Unlock()
- return d.err
- }
- // setErr is a thread-safe setter for the error object
- func (d *downloader) setErr(e error) {
- d.m.Lock()
- defer d.m.Unlock()
- d.err = e
- }
// dlchunk represents a single chunk of data to write by the worker routine.
// It adapts an io.WriterAt into an io.Writer limited to the byte section
// [start, start+size) of the destination. It is the write-side counterpart
// of io.SectionReader; the standard library has no io.SectionWriter.
type dlchunk struct {
	// w is the destination shared by all chunks of a download.
	w io.WriterAt
	// start is the absolute offset of this chunk's section in w.
	start int64
	// size is the length of the section in bytes.
	size int64
	// cur is the number of bytes written so far, relative to start.
	cur int64
}

// Write writes p at the chunk's current position and advances the position.
//
// Writes never extend past the end of the section. If p does not fit in the
// remaining space, only the bytes that fit are written and io.EOF is returned
// with that short count, so one chunk cannot overwrite a neighboring chunk's
// range. Once the section is full, Write returns 0, io.EOF.
func (c *dlchunk) Write(p []byte) (n int, err error) {
	if c.cur >= c.size {
		return 0, io.EOF
	}

	// Clamp the write to the space left in the section.
	remaining := c.size - c.cur
	truncated := int64(len(p)) > remaining
	if truncated {
		p = p[:remaining]
	}

	n, err = c.w.WriteAt(p, c.start+c.cur)
	c.cur += int64(n)

	// io.Writer requires a non-nil error whenever n < len(original p).
	if err == nil && truncated {
		err = io.EOF
	}
	return n, err
}
|