credentials.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package cloudwatch
  2. import (
  3. "fmt"
  4. "os"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/aws/aws-sdk-go/aws"
  9. "github.com/aws/aws-sdk-go/aws/credentials"
  10. "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
  11. "github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
  12. "github.com/aws/aws-sdk-go/aws/ec2metadata"
  13. "github.com/aws/aws-sdk-go/aws/session"
  14. "github.com/aws/aws-sdk-go/service/sts"
  15. )
  // DatasourceInfo carries the per-datasource AWS settings that this file
  // consults when building credentials and client configuration.
  type DatasourceInfo struct {
  	// Profile is the profile name handed to the shared credentials provider.
  	Profile string
  	// Region is the AWS region written into the aws.Config values built here.
  	Region string
  	// AuthType selects the authentication mode; GetCredentials only assumes
  	// a role when it equals "arn".
  	AuthType string
  	// AssumeRoleArn is the role ARN passed to STS AssumeRole.
  	AssumeRoleArn string
  	// Namespace is not read anywhere in this file.
  	// NOTE(review): presumably the CloudWatch metric namespace; confirm against callers.
  	Namespace string
  	// AccessKey and SecretKey are a static key pair offered as one of the
  	// providers in the credential chain.
  	AccessKey, SecretKey string
  }
  25. type cache struct {
  26. credential *credentials.Credentials
  27. expiration *time.Time
  28. }
  29. var awsCredentialCache map[string]cache = make(map[string]cache)
  30. var credentialCacheLock sync.RWMutex
  31. func GetCredentials(dsInfo *DatasourceInfo) (*credentials.Credentials, error) {
  32. cacheKey := dsInfo.AccessKey + ":" + dsInfo.Profile + ":" + dsInfo.AssumeRoleArn
  33. credentialCacheLock.RLock()
  34. if _, ok := awsCredentialCache[cacheKey]; ok {
  35. if awsCredentialCache[cacheKey].expiration != nil &&
  36. (*awsCredentialCache[cacheKey].expiration).After(time.Now().UTC()) {
  37. result := awsCredentialCache[cacheKey].credential
  38. credentialCacheLock.RUnlock()
  39. return result, nil
  40. }
  41. }
  42. credentialCacheLock.RUnlock()
  43. accessKeyId := ""
  44. secretAccessKey := ""
  45. sessionToken := ""
  46. var expiration *time.Time
  47. expiration = nil
  48. if dsInfo.AuthType == "arn" && strings.Index(dsInfo.AssumeRoleArn, "arn:aws:iam:") == 0 {
  49. params := &sts.AssumeRoleInput{
  50. RoleArn: aws.String(dsInfo.AssumeRoleArn),
  51. RoleSessionName: aws.String("GrafanaSession"),
  52. DurationSeconds: aws.Int64(900),
  53. }
  54. stsSess, err := session.NewSession()
  55. if err != nil {
  56. return nil, err
  57. }
  58. stsCreds := credentials.NewChainCredentials(
  59. []credentials.Provider{
  60. &credentials.EnvProvider{},
  61. &credentials.SharedCredentialsProvider{Filename: "", Profile: dsInfo.Profile},
  62. remoteCredProvider(stsSess),
  63. })
  64. stsConfig := &aws.Config{
  65. Region: aws.String(dsInfo.Region),
  66. Credentials: stsCreds,
  67. }
  68. sess, err := session.NewSession(stsConfig)
  69. if err != nil {
  70. return nil, err
  71. }
  72. svc := sts.New(sess, stsConfig)
  73. resp, err := svc.AssumeRole(params)
  74. if err != nil {
  75. return nil, err
  76. }
  77. if resp.Credentials != nil {
  78. accessKeyId = *resp.Credentials.AccessKeyId
  79. secretAccessKey = *resp.Credentials.SecretAccessKey
  80. sessionToken = *resp.Credentials.SessionToken
  81. expiration = resp.Credentials.Expiration
  82. }
  83. } else {
  84. now := time.Now()
  85. e := now.Add(5 * time.Minute)
  86. expiration = &e
  87. }
  88. sess, err := session.NewSession()
  89. if err != nil {
  90. return nil, err
  91. }
  92. creds := credentials.NewChainCredentials(
  93. []credentials.Provider{
  94. &credentials.StaticProvider{Value: credentials.Value{
  95. AccessKeyID: accessKeyId,
  96. SecretAccessKey: secretAccessKey,
  97. SessionToken: sessionToken,
  98. }},
  99. &credentials.EnvProvider{},
  100. &credentials.StaticProvider{Value: credentials.Value{
  101. AccessKeyID: dsInfo.AccessKey,
  102. SecretAccessKey: dsInfo.SecretKey,
  103. }},
  104. &credentials.SharedCredentialsProvider{Filename: "", Profile: dsInfo.Profile},
  105. remoteCredProvider(sess),
  106. })
  107. credentialCacheLock.Lock()
  108. awsCredentialCache[cacheKey] = cache{
  109. credential: creds,
  110. expiration: expiration,
  111. }
  112. credentialCacheLock.Unlock()
  113. return creds, nil
  114. }
  115. func remoteCredProvider(sess *session.Session) credentials.Provider {
  116. ecsCredURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
  117. if len(ecsCredURI) > 0 {
  118. return ecsCredProvider(sess, ecsCredURI)
  119. }
  120. return ec2RoleProvider(sess)
  121. }
  122. func ecsCredProvider(sess *session.Session, uri string) credentials.Provider {
  123. const host = `169.254.170.2`
  124. c := ec2metadata.New(sess)
  125. return endpointcreds.NewProviderClient(
  126. c.Client.Config,
  127. c.Client.Handlers,
  128. fmt.Sprintf("http://%s%s", host, uri),
  129. func(p *endpointcreds.Provider) { p.ExpiryWindow = 5 * time.Minute })
  130. }
  131. func ec2RoleProvider(sess *session.Session) credentials.Provider {
  132. return &ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess), ExpiryWindow: 5 * time.Minute}
  133. }
  134. func getAwsConfig(dsInfo *DatasourceInfo) (*aws.Config, error) {
  135. creds, err := GetCredentials(dsInfo)
  136. if err != nil {
  137. return nil, err
  138. }
  139. cfg := &aws.Config{
  140. Region: aws.String(dsInfo.Region),
  141. Credentials: creds,
  142. }
  143. return cfg, nil
  144. }