generate.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. package main
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/url"
  7. "os"
  8. "os/exec"
  9. "reflect"
  10. "regexp"
  11. "sort"
  12. "strconv"
  13. "strings"
  14. "text/template"
  15. "github.com/aws/aws-sdk-go/private/model/api"
  16. "github.com/aws/aws-sdk-go/private/util"
  17. )
  // Test suite kinds; they select which template a suite's cases are rendered with.
const (
	// TestSuiteTypeInput marks a suite of input tests (request building and
	// serialization assertions).
	TestSuiteTypeInput = iota
	// TestSuiteTypeOutput marks a suite of output tests (response
	// unmarshaling assertions).
	TestSuiteTypeOutput
)
  24. type testSuite struct {
  25. *api.API
  26. Description string
  27. Cases []testCase
  28. Type uint
  29. title string
  30. }
  31. type testCase struct {
  32. TestSuite *testSuite
  33. Given *api.Operation
  34. Params interface{} `json:",omitempty"`
  35. Data interface{} `json:"result,omitempty"`
  36. InputTest testExpectation `json:"serialized"`
  37. OutputTest testExpectation `json:"response"`
  38. }
  39. type testExpectation struct {
  40. Body string
  41. URI string
  42. Headers map[string]string
  43. StatusCode uint `json:"status_code"`
  44. }
  // preamble is written once, directly after the import block, near the top of
// the generated test file (generateTestSuite emits it for the first suite only).
// Each blank-identifier declaration references one imported package, so the
// generated file still compiles when no generated code happens to use that
// package. The init function swaps protocol.RandReader for
// awstesting.ZeroReader, presumably so any randomness the protocol layer draws
// is deterministic in tests (TODO confirm).
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
func init() {
protocol.RandReader = &awstesting.ZeroReader{}
}
`
  // Regular expressions used to rewrite generated source text.
var (
	// reStripSpace matches a whitespace character followed by a word
	// character; TestSuite uses it to CamelCase a suite description.
	reStripSpace = regexp.MustCompile(`\s(\w)`)
	// reImportRemoval matches a parenthesized import declaration and captures
	// the text between the parentheses.
	reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
)

// removeImports deletes every parenthesized import declaration from code.
func removeImports(code string) string {
	return reImportRemoval.ReplaceAllLiteralString(code, "")
}

// extraImports are the packages every generated test file needs in addition
// to those the service code imports. The empty entry becomes a blank line that
// separates the standard-library group from the rest.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"fmt",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"net/url",
	"",
	"github.com/aws/aws-sdk-go/awstesting",
	"github.com/aws/aws-sdk-go/awstesting/unit",
	"github.com/aws/aws-sdk-go/private/protocol",
	"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
	"github.com/aws/aws-sdk-go/private/util",
	"github.com/stretchr/testify/assert",
}

// addImports rewrites each parenthesized import declaration in code so that it
// lists extraImports (one quoted path per line) ahead of its original contents.
// Code without such a declaration is returned unchanged.
func addImports(code string) string {
	lines := make([]string, 0, len(extraImports))
	for _, path := range extraImports {
		if path == "" {
			lines = append(lines, "")
			continue
		}
		lines = append(lines, strconv.Quote(path))
	}
	// "$1" re-inserts the original import list after the extra imports.
	replacement := "import (\n" + strings.Join(lines, "\n") + "$1\n)"
	return reImportRemoval.ReplaceAllString(code, replacement)
}
  96. func (t *testSuite) TestSuite() string {
  97. var buf bytes.Buffer
  98. t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
  99. return strings.ToUpper(x[1:])
  100. })
  101. t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
  102. for idx, c := range t.Cases {
  103. c.TestSuite = t
  104. buf.WriteString(c.TestCase(idx) + "\n")
  105. }
  106. return buf.String()
  107. }
  108. var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
  109. func Test{{ .OpName }}(t *testing.T) {
  110. svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
  111. {{ if ne .ParamsString "" }}input := {{ .ParamsString }}
  112. req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
  113. r := req.HTTPRequest
  114. // build request
  115. {{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
  116. assert.NoError(t, req.Error)
  117. {{ if ne .TestCase.InputTest.Body "" }}// assert body
  118. assert.NotNil(t, r.Body)
  119. {{ .BodyAssertions }}{{ end }}
  120. {{ if ne .TestCase.InputTest.URI "" }}// assert URL
  121. awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}
  122. // assert headers
  123. {{ range $k, $v := .TestCase.InputTest.Headers }}assert.Equal(t, "{{ $v }}", r.Header.Get("{{ $k }}"))
  124. {{ end }}
  125. }
  126. `))
  127. type tplInputTestCaseData struct {
  128. TestCase *testCase
  129. OpName, ParamsString string
  130. }
  131. func (t tplInputTestCaseData) BodyAssertions() string {
  132. code := &bytes.Buffer{}
  133. protocol := t.TestCase.TestSuite.API.Metadata.Protocol
  134. // Extract the body bytes
  135. switch protocol {
  136. case "rest-xml":
  137. fmt.Fprintln(code, "body := util.SortXML(r.Body)")
  138. default:
  139. fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
  140. }
  141. // Generate the body verification code
  142. expectedBody := util.Trim(t.TestCase.InputTest.Body)
  143. switch protocol {
  144. case "ec2", "query":
  145. fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
  146. expectedBody)
  147. case "rest-xml":
  148. if strings.HasPrefix(expectedBody, "<") {
  149. fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})",
  150. expectedBody, t.TestCase.Given.InputRef.ShapeName)
  151. } else {
  152. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  153. expectedBody)
  154. }
  155. case "json", "jsonrpc", "rest-json":
  156. if strings.HasPrefix(expectedBody, "{") {
  157. fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
  158. expectedBody)
  159. } else {
  160. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  161. expectedBody)
  162. }
  163. default:
  164. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  165. expectedBody)
  166. }
  167. return code.String()
  168. }
  169. var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
  170. func Test{{ .OpName }}(t *testing.T) {
  171. svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
  172. buf := bytes.NewReader([]byte({{ .Body }}))
  173. req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
  174. req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
  175. // set headers
  176. {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
  177. {{ end }}
  178. // unmarshal response
  179. {{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
  180. {{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
  181. assert.NoError(t, req.Error)
  182. // assert response
  183. assert.NotNil(t, out) // ensure out variable is used
  184. {{ .Assertions }}
  185. }
  186. `))
  187. type tplOutputTestCaseData struct {
  188. TestCase *testCase
  189. Body, OpName, Assertions string
  190. }
  191. func (i *testCase) TestCase(idx int) string {
  192. var buf bytes.Buffer
  193. opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
  194. if i.TestSuite.Type == TestSuiteTypeInput { // input test
  195. // query test should sort body as form encoded values
  196. switch i.TestSuite.API.Metadata.Protocol {
  197. case "query", "ec2":
  198. m, _ := url.ParseQuery(i.InputTest.Body)
  199. i.InputTest.Body = m.Encode()
  200. case "rest-xml":
  201. i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
  202. case "json", "rest-json":
  203. i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
  204. }
  205. input := tplInputTestCaseData{
  206. TestCase: i,
  207. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  208. ParamsString: api.ParamsStructFromJSON(i.Params, i.Given.InputRef.Shape, false),
  209. }
  210. if err := tplInputTestCase.Execute(&buf, input); err != nil {
  211. panic(err)
  212. }
  213. } else if i.TestSuite.Type == TestSuiteTypeOutput {
  214. output := tplOutputTestCaseData{
  215. TestCase: i,
  216. Body: fmt.Sprintf("%q", i.OutputTest.Body),
  217. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  218. Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
  219. }
  220. if err := tplOutputTestCase.Execute(&buf, output); err != nil {
  221. panic(err)
  222. }
  223. }
  224. return buf.String()
  225. }
  226. // generateTestSuite generates a protocol test suite for a given configuration
  227. // JSON protocol test file.
  228. func generateTestSuite(filename string) string {
  229. inout := "Input"
  230. if strings.Contains(filename, "output/") {
  231. inout = "Output"
  232. }
  233. var suites []testSuite
  234. f, err := os.Open(filename)
  235. if err != nil {
  236. panic(err)
  237. }
  238. err = json.NewDecoder(f).Decode(&suites)
  239. if err != nil {
  240. panic(err)
  241. }
  242. var buf bytes.Buffer
  243. buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
  244. var innerBuf bytes.Buffer
  245. innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
  246. for i, suite := range suites {
  247. svcPrefix := inout + "Service" + strconv.Itoa(i+1)
  248. suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
  249. suite.API.Operations = map[string]*api.Operation{}
  250. for idx, c := range suite.Cases {
  251. c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
  252. suite.API.Operations[c.Given.ExportedName] = c.Given
  253. }
  254. suite.Type = getType(inout)
  255. suite.API.NoInitMethods = true // don't generate init methods
  256. suite.API.NoStringerMethods = true // don't generate stringer methods
  257. suite.API.NoConstServiceNames = true // don't generate service names
  258. suite.API.Setup()
  259. suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
  260. // Sort in order for deterministic test generation
  261. names := make([]string, 0, len(suite.API.Shapes))
  262. for n := range suite.API.Shapes {
  263. names = append(names, n)
  264. }
  265. sort.Strings(names)
  266. for _, name := range names {
  267. s := suite.API.Shapes[name]
  268. s.Rename(svcPrefix + "TestShape" + name)
  269. }
  270. svcCode := addImports(suite.API.ServiceGoCode())
  271. if i == 0 {
  272. importMatch := reImportRemoval.FindStringSubmatch(svcCode)
  273. buf.WriteString(importMatch[0] + "\n\n")
  274. buf.WriteString(preamble + "\n\n")
  275. }
  276. svcCode = removeImports(svcCode)
  277. svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
  278. svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
  279. svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
  280. buf.WriteString(svcCode + "\n\n")
  281. apiCode := removeImports(suite.API.APIGoCode())
  282. apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
  283. apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
  284. apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
  285. buf.WriteString(apiCode + "\n\n")
  286. innerBuf.WriteString(suite.TestSuite() + "\n")
  287. }
  288. return buf.String() + innerBuf.String()
  289. }
  290. // findMember searches the shape for the member with the matching key name.
  291. func findMember(shape *api.Shape, key string) string {
  292. for actualKey := range shape.MemberRefs {
  293. if strings.ToLower(key) == strings.ToLower(actualKey) {
  294. return actualKey
  295. }
  296. }
  297. return ""
  298. }
  299. // GenerateAssertions builds assertions for a shape based on its type.
  300. //
  301. // The shape's recursive values also will have assertions generated for them.
  302. func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
  303. switch t := out.(type) {
  304. case map[string]interface{}:
  305. keys := util.SortedKeys(t)
  306. code := ""
  307. if shape.Type == "map" {
  308. for _, k := range keys {
  309. v := t[k]
  310. s := shape.ValueRef.Shape
  311. code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
  312. }
  313. } else {
  314. for _, k := range keys {
  315. v := t[k]
  316. m := findMember(shape, k)
  317. s := shape.MemberRefs[m].Shape
  318. code += GenerateAssertions(v, s, prefix+"."+m+"")
  319. }
  320. }
  321. return code
  322. case []interface{}:
  323. code := ""
  324. for i, v := range t {
  325. s := shape.MemberRef.Shape
  326. code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
  327. }
  328. return code
  329. default:
  330. switch shape.Type {
  331. case "timestamp":
  332. return fmt.Sprintf("assert.Equal(t, time.Unix(%#v, 0).UTC().String(), %s.String())\n", out, prefix)
  333. case "blob":
  334. return fmt.Sprintf("assert.Equal(t, %#v, string(%s))\n", out, prefix)
  335. case "integer", "long":
  336. return fmt.Sprintf("assert.Equal(t, int64(%#v), *%s)\n", out, prefix)
  337. default:
  338. if !reflect.ValueOf(out).IsValid() {
  339. return fmt.Sprintf("assert.Nil(t, %s)\n", prefix)
  340. }
  341. return fmt.Sprintf("assert.Equal(t, %#v, *%s)\n", out, prefix)
  342. }
  343. }
  344. }
  345. func getType(t string) uint {
  346. switch t {
  347. case "Input":
  348. return TestSuiteTypeInput
  349. case "Output":
  350. return TestSuiteTypeOutput
  351. default:
  352. panic("Invalid type for test suite")
  353. }
  354. }
  355. func main() {
  356. out := generateTestSuite(os.Args[1])
  357. if len(os.Args) == 3 {
  358. f, err := os.Create(os.Args[2])
  359. defer f.Close()
  360. if err != nil {
  361. panic(err)
  362. }
  363. f.WriteString(util.GoFmt(out))
  364. f.Close()
  365. c := exec.Command("gofmt", "-s", "-w", os.Args[2])
  366. if err := c.Run(); err != nil {
  367. panic(err)
  368. }
  369. } else {
  370. fmt.Println(out)
  371. }
  372. }