provider.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. // Package endpointcreds provides support for retrieving credentials from an
  2. // arbitrary HTTP endpoint.
  3. //
  4. // The credentials endpoint Provider can receive both static and refreshable
  5. // credentials that will expire. Credentials are static when an "Expiration"
  6. // value is not provided in the endpoint's response.
  7. //
  8. // Static credentials will never expire once they have been retrieved. The format
  9. // of the static credentials response:
  10. // {
  11. // "AccessKeyId" : "MUA...",
  12. // "SecretAccessKey" : "/7PC5om...."
  13. // }
  14. //
  15. // Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
  16. // value in the response. The format of the refreshable credentials response:
  17. // {
  18. // "AccessKeyId" : "MUA...",
  19. // "SecretAccessKey" : "/7PC5om....",
  20. // "Token" : "AQoDY....=",
  21. // "Expiration" : "2016-02-25T06:03:31Z"
  22. // }
  23. //
  24. // Errors should be returned in the following format and only returned with 400
  25. // or 500 HTTP status codes.
  26. // {
  27. // "code": "ErrorCode",
  28. // "message": "Helpful error message."
  29. // }
  30. package endpointcreds
  31. import (
  32. "encoding/json"
  33. "time"
  34. "github.com/aws/aws-sdk-go/aws"
  35. "github.com/aws/aws-sdk-go/aws/awserr"
  36. "github.com/aws/aws-sdk-go/aws/client"
  37. "github.com/aws/aws-sdk-go/aws/client/metadata"
  38. "github.com/aws/aws-sdk-go/aws/credentials"
  39. "github.com/aws/aws-sdk-go/aws/request"
  40. )
  // ProviderName identifies this credentials provider. Retrieve reports it in
// the ProviderName field of every credentials.Value it returns.
const ProviderName = "CredentialsEndpointProvider"
  43. // Provider satisfies the credentials.Provider interface, and is a client to
  44. // retrieve credentials from an arbitrary endpoint.
  45. type Provider struct {
  46. staticCreds bool
  47. credentials.Expiry
  48. // Requires a AWS Client to make HTTP requests to the endpoint with.
  49. // the Endpoint the request will be made to is provided by the aws.Config's
  50. // Endpoint value.
  51. Client *client.Client
  52. // ExpiryWindow will allow the credentials to trigger refreshing prior to
  53. // the credentials actually expiring. This is beneficial so race conditions
  54. // with expiring credentials do not cause request to fail unexpectedly
  55. // due to ExpiredTokenException exceptions.
  56. //
  57. // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
  58. // 10 seconds before the credentials are actually expired.
  59. //
  60. // If ExpiryWindow is 0 or less it will be ignored.
  61. ExpiryWindow time.Duration
  62. // Optional authorization token value if set will be used as the value of
  63. // the Authorization header of the endpoint credential request.
  64. AuthorizationToken string
  65. }
  66. // NewProviderClient returns a credentials Provider for retrieving AWS credentials
  67. // from arbitrary endpoint.
  68. func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
  69. p := &Provider{
  70. Client: client.New(
  71. cfg,
  72. metadata.ClientInfo{
  73. ServiceName: "CredentialsEndpoint",
  74. Endpoint: endpoint,
  75. },
  76. handlers,
  77. ),
  78. }
  79. p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
  80. p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
  81. p.Client.Handlers.Validate.Clear()
  82. p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
  83. for _, option := range options {
  84. option(p)
  85. }
  86. return p
  87. }
  88. // NewCredentialsClient returns a Credentials wrapper for retrieving credentials
  89. // from an arbitrary endpoint concurrently. The client will request the
  90. func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
  91. return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
  92. }
  93. // IsExpired returns true if the credentials retrieved are expired, or not yet
  94. // retrieved.
  95. func (p *Provider) IsExpired() bool {
  96. if p.staticCreds {
  97. return false
  98. }
  99. return p.Expiry.IsExpired()
  100. }
  101. // Retrieve will attempt to request the credentials from the endpoint the Provider
  102. // was configured for. And error will be returned if the retrieval fails.
  103. func (p *Provider) Retrieve() (credentials.Value, error) {
  104. resp, err := p.getCredentials()
  105. if err != nil {
  106. return credentials.Value{ProviderName: ProviderName},
  107. awserr.New("CredentialsEndpointError", "failed to load credentials", err)
  108. }
  109. if resp.Expiration != nil {
  110. p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
  111. } else {
  112. p.staticCreds = true
  113. }
  114. return credentials.Value{
  115. AccessKeyID: resp.AccessKeyID,
  116. SecretAccessKey: resp.SecretAccessKey,
  117. SessionToken: resp.Token,
  118. ProviderName: ProviderName,
  119. }, nil
  120. }
  // getCredentialsOutput is the decoded body of a successful credentials
// response.
//
// The fields have no JSON tags, so encoding/json matches the endpoint's keys
// to field names case-insensitively (for example "AccessKeyId" fills
// AccessKeyID). A nil Expiration means the endpoint returned static
// credentials.
type getCredentialsOutput struct {
	Expiration      *time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string
}
  // errorOutput is the decoded body of an error response from the endpoint,
// for example {"code": "ErrorCode", "message": "Helpful error message."}.
type errorOutput struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
  131. func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
  132. op := &request.Operation{
  133. Name: "GetCredentials",
  134. HTTPMethod: "GET",
  135. }
  136. out := &getCredentialsOutput{}
  137. req := p.Client.NewRequest(op, nil, out)
  138. req.HTTPRequest.Header.Set("Accept", "application/json")
  139. if authToken := p.AuthorizationToken; len(authToken) != 0 {
  140. req.HTTPRequest.Header.Set("Authorization", authToken)
  141. }
  142. return out, req.Send()
  143. }
  144. func validateEndpointHandler(r *request.Request) {
  145. if len(r.ClientInfo.Endpoint) == 0 {
  146. r.Error = aws.ErrMissingEndpoint
  147. }
  148. }
  149. func unmarshalHandler(r *request.Request) {
  150. defer r.HTTPResponse.Body.Close()
  151. out := r.Data.(*getCredentialsOutput)
  152. if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
  153. r.Error = awserr.New("SerializationError",
  154. "failed to decode endpoint credentials",
  155. err,
  156. )
  157. }
  158. }
  159. func unmarshalError(r *request.Request) {
  160. defer r.HTTPResponse.Body.Close()
  161. var errOut errorOutput
  162. if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
  163. r.Error = awserr.New("SerializationError",
  164. "failed to decode endpoint credentials",
  165. err,
  166. )
  167. }
  168. // Response body format is not consistent between metadata endpoints.
  169. // Grab the error message as a string and include that as the source error
  170. r.Error = awserr.New(errOut.Code, errOut.Message, nil)
  171. }