gen.go 13 KB


  1. // Copyright (c) 2012, Sean Treadway, SoundCloud Ltd.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Source code and contact info at http://github.com/streadway/amqp
  5. // +build ignore
  6. package main
  7. import (
  8. "bytes"
  9. "encoding/xml"
  10. "errors"
  11. "fmt"
  12. "io/ioutil"
  13. "log"
  14. "os"
  15. "regexp"
  16. "strings"
  17. "text/template"
  18. )
  19. var (
  20. ErrUnknownType = errors.New("Unknown field type in gen")
  21. ErrUnknownDomain = errors.New("Unknown domain type in gen")
  22. )
  23. var amqpTypeToNative = map[string]string{
  24. "bit": "bool",
  25. "octet": "byte",
  26. "shortshort": "uint8",
  27. "short": "uint16",
  28. "long": "uint32",
  29. "longlong": "uint64",
  30. "timestamp": "time.Time",
  31. "table": "Table",
  32. "shortstr": "string",
  33. "longstr": "string",
  34. }
  35. type Rule struct {
  36. Name string `xml:"name,attr"`
  37. Docs []string `xml:"doc"`
  38. }
  39. type Doc struct {
  40. Type string `xml:"type,attr"`
  41. Body string `xml:",innerxml"`
  42. }
  43. type Chassis struct {
  44. Name string `xml:"name,attr"`
  45. Implement string `xml:"implement,attr"`
  46. }
  47. type Assert struct {
  48. Check string `xml:"check,attr"`
  49. Value string `xml:"value,attr"`
  50. Method string `xml:"method,attr"`
  51. }
  52. type Field struct {
  53. Name string `xml:"name,attr"`
  54. Domain string `xml:"domain,attr"`
  55. Type string `xml:"type,attr"`
  56. Label string `xml:"label,attr"`
  57. Reserved bool `xml:"reserved,attr"`
  58. Docs []Doc `xml:"doc"`
  59. Asserts []Assert `xml:"assert"`
  60. }
  61. type Response struct {
  62. Name string `xml:"name,attr"`
  63. }
  64. type Method struct {
  65. Name string `xml:"name,attr"`
  66. Response Response `xml:"response"`
  67. Synchronous bool `xml:"synchronous,attr"`
  68. Content bool `xml:"content,attr"`
  69. Index string `xml:"index,attr"`
  70. Label string `xml:"label,attr"`
  71. Docs []Doc `xml:"doc"`
  72. Rules []Rule `xml:"rule"`
  73. Fields []Field `xml:"field"`
  74. Chassis []Chassis `xml:"chassis"`
  75. }
  76. type Class struct {
  77. Name string `xml:"name,attr"`
  78. Handler string `xml:"handler,attr"`
  79. Index string `xml:"index,attr"`
  80. Label string `xml:"label,attr"`
  81. Docs []Doc `xml:"doc"`
  82. Methods []Method `xml:"method"`
  83. Chassis []Chassis `xml:"chassis"`
  84. }
  85. type Domain struct {
  86. Name string `xml:"name,attr"`
  87. Type string `xml:"type,attr"`
  88. Label string `xml:"label,attr"`
  89. Rules []Rule `xml:"rule"`
  90. Docs []Doc `xml:"doc"`
  91. }
  92. type Constant struct {
  93. Name string `xml:"name,attr"`
  94. Value int `xml:"value,attr"`
  95. Class string `xml:"class,attr"`
  96. Doc string `xml:"doc"`
  97. }
  98. type Amqp struct {
  99. Major int `xml:"major,attr"`
  100. Minor int `xml:"minor,attr"`
  101. Port int `xml:"port,attr"`
  102. Comment string `xml:"comment,attr"`
  103. Constants []Constant `xml:"constant"`
  104. Domains []Domain `xml:"domain"`
  105. Classes []Class `xml:"class"`
  106. }
  107. type renderer struct {
  108. Root Amqp
  109. bitcounter int
  110. }
  111. type fieldset struct {
  112. AmqpType string
  113. NativeType string
  114. Fields []Field
  115. *renderer
  116. }
// Template machinery for the generated amqp package.
var (
	// helpers are the functions available to packageTemplate and every
	// sub-template it defines: public/private convert hyphenated spec names
	// into Go identifiers, and clean strips carriage returns from doc text.
	helpers = template.FuncMap{
		"public":  public,
		"private": private,
		"clean":   clean,
	}

	// packageTemplate renders the whole generated package. main executes it
	// with a *renderer, so inside it `$` is that renderer and `{{with .Root}}`
	// descends into the parsed spec. It emits, in order:
	//   - the spec constants (error-class constants exported via public,
	//     the rest unexported via private),
	//   - isSoftExceptionCode, a switch over the "soft-error" constants,
	//   - one struct per class/method with id(), wait(), getContent()/
	//     setContent() for content-bearing methods, and write()/read(),
	//   - reader.parseMethodFrame, which reads the class and method ids and
	//     dispatches to the matching struct's read().
	//
	// The write()/read() bodies come from the "enc-<type>"/"dec-<type>"
	// sub-templates defined at the bottom. renderer.Fieldsets groups runs of
	// same-typed fields and renderer.Partial executes the matching
	// sub-template with each fieldset as data, so `$` inside a sub-template
	// is a fieldset and `$.FieldName` resolves through its embedded *renderer.
	//
	// NOTE(review): enc-bit ORs flags into the single `var bits byte` declared
	// once per method and never clears it, so a method with two non-adjacent
	// groups of bit fields would carry bits from the first group into the
	// second. Confirm no method in the spec has that shape.
	//
	// NOTE(review): the generated file imports only fmt, encoding/binary and
	// io, while amqpTypeToNative maps "timestamp" to time.Time. A method field
	// of type timestamp would therefore produce code that does not compile —
	// presumably no method field uses it; confirm against the spec.
	packageTemplate = template.Must(template.New("package").Funcs(helpers).Parse(`
// Copyright (c) 2012, Sean Treadway, SoundCloud Ltd.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Source code and contact info at http://github.com/streadway/amqp
/* GENERATED FILE - DO NOT EDIT */
/* Rebuild from the spec/gen.go tool */
{{with .Root}}
package amqp
import (
"fmt"
"encoding/binary"
"io"
)
// Error codes that can be sent from the server during a connection or
// channel exception or used by the client to indicate a class of error like
// ErrCredentials. The text of the error is likely more interesting than
// these constants.
const (
{{range $c := .Constants}}
{{if $c.IsError}}{{.Name | public}}{{else}}{{.Name | private}}{{end}} = {{.Value}}{{end}}
)
func isSoftExceptionCode(code int) bool {
switch code {
{{range $c := .Constants}} {{if $c.IsSoftError}} case {{$c.Value}}:
return true
{{end}}{{end}}
}
return false
}
{{range .Classes}}
{{$class := .}}
{{range .Methods}}
{{$method := .}}
{{$struct := $.StructName $class.Name $method.Name}}
{{if .Docs}}/* {{range .Docs}} {{.Body | clean}} {{end}} */{{end}}
type {{$struct}} struct {
{{range .Fields}}
{{$.FieldName .}} {{$.FieldType . | $.NativeType}} {{if .Label}}// {{.Label}}{{end}}{{end}}
{{if .Content}}Properties properties
Body []byte{{end}}
}
func (me *{{$struct}}) id() (uint16, uint16) {
return {{$class.Index}}, {{$method.Index}}
}
func (me *{{$struct}}) wait() (bool) {
return {{.Synchronous}}{{if $.HasField "NoWait" .}} && !me.NoWait{{end}}
}
{{if .Content}}
func (me *{{$struct}}) getContent() (properties, []byte) {
return me.Properties, me.Body
}
func (me *{{$struct}}) setContent(props properties, body []byte) {
me.Properties, me.Body = props, body
}
{{end}}
func (me *{{$struct}}) write(w io.Writer) (err error) {
{{if $.HasType "bit" $method}}var bits byte{{end}}
{{.Fields | $.Fieldsets | $.Partial "enc-"}}
return
}
func (me *{{$struct}}) read(r io.Reader) (err error) {
{{if $.HasType "bit" $method}}var bits byte{{end}}
{{.Fields | $.Fieldsets | $.Partial "dec-"}}
return
}
{{end}}
{{end}}
func (me *reader) parseMethodFrame(channel uint16, size uint32) (f frame, err error) {
mf := &methodFrame {
ChannelId: channel,
}
if err = binary.Read(me.r, binary.BigEndian, &mf.ClassId); err != nil {
return
}
if err = binary.Read(me.r, binary.BigEndian, &mf.MethodId); err != nil {
return
}
switch mf.ClassId {
{{range .Classes}}
{{$class := .}}
case {{.Index}}: // {{.Name}}
switch mf.MethodId {
{{range .Methods}}
case {{.Index}}: // {{$class.Name}} {{.Name}}
//fmt.Println("NextMethod: class:{{$class.Index}} method:{{.Index}}")
method := &{{$.StructName $class.Name .Name}}{}
if err = method.read(me.r); err != nil {
return
}
mf.Method = method
{{end}}
default:
return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId)
}
{{end}}
default:
return nil, fmt.Errorf("Bad method frame, unknown class %d", mf.ClassId)
}
return mf, nil
}
{{end}}
{{define "enc-bit"}}
{{range $off, $field := .Fields}}
if me.{{$field | $.FieldName}} { bits |= 1 << {{$off}} }
{{end}}
if err = binary.Write(w, binary.BigEndian, bits); err != nil { return }
{{end}}
{{define "enc-octet"}}
{{range .Fields}} if err = binary.Write(w, binary.BigEndian, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-shortshort"}}
{{range .Fields}} if err = binary.Write(w, binary.BigEndian, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-short"}}
{{range .Fields}} if err = binary.Write(w, binary.BigEndian, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-long"}}
{{range .Fields}} if err = binary.Write(w, binary.BigEndian, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-longlong"}}
{{range .Fields}} if err = binary.Write(w, binary.BigEndian, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-timestamp"}}
{{range .Fields}} if err = writeTimestamp(w, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-shortstr"}}
{{range .Fields}} if err = writeShortstr(w, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-longstr"}}
{{range .Fields}} if err = writeLongstr(w, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "enc-table"}}
{{range .Fields}} if err = writeTable(w, me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-bit"}}
if err = binary.Read(r, binary.BigEndian, &bits); err != nil {
return
}
{{range $off, $field := .Fields}} me.{{$field | $.FieldName}} = (bits & (1 << {{$off}}) > 0)
{{end}}
{{end}}
{{define "dec-octet"}}
{{range .Fields}} if err = binary.Read(r, binary.BigEndian, &me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-shortshort"}}
{{range .Fields}} if err = binary.Read(r, binary.BigEndian, &me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-short"}}
{{range .Fields}} if err = binary.Read(r, binary.BigEndian, &me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-long"}}
{{range .Fields}} if err = binary.Read(r, binary.BigEndian, &me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-longlong"}}
{{range .Fields}} if err = binary.Read(r, binary.BigEndian, &me.{{. | $.FieldName}}); err != nil { return }
{{end}}
{{end}}
{{define "dec-timestamp"}}
{{range .Fields}} if me.{{. | $.FieldName}}, err = readTimestamp(r); err != nil { return }
{{end}}
{{end}}
{{define "dec-shortstr"}}
{{range .Fields}} if me.{{. | $.FieldName}}, err = readShortstr(r); err != nil { return }
{{end}}
{{end}}
{{define "dec-longstr"}}
{{range .Fields}} if me.{{. | $.FieldName}}, err = readLongstr(r); err != nil { return }
{{end}}
{{end}}
{{define "dec-table"}}
{{range .Fields}} if me.{{. | $.FieldName}}, err = readTable(r); err != nil { return }
{{end}}
{{end}}
`))
)
  312. func (me *Constant) IsError() bool {
  313. return strings.Contains(me.Class, "error")
  314. }
  315. func (me *Constant) IsSoftError() bool {
  316. return me.Class == "soft-error"
  317. }
  318. func (me *renderer) Partial(prefix string, fields []fieldset) (s string, err error) {
  319. var buf bytes.Buffer
  320. for _, set := range fields {
  321. name := prefix + set.AmqpType
  322. t := packageTemplate.Lookup(name)
  323. if t == nil {
  324. return "", errors.New(fmt.Sprintf("Missing template: %s", name))
  325. }
  326. if err = t.Execute(&buf, set); err != nil {
  327. return
  328. }
  329. }
  330. return string(buf.Bytes()), nil
  331. }
  332. // Groups the fields so that the right encoder/decoder can be called
  333. func (me *renderer) Fieldsets(fields []Field) (f []fieldset, err error) {
  334. if len(fields) > 0 {
  335. for _, field := range fields {
  336. cur := fieldset{}
  337. cur.AmqpType, err = me.FieldType(field)
  338. if err != nil {
  339. return
  340. }
  341. cur.NativeType, err = me.NativeType(cur.AmqpType)
  342. if err != nil {
  343. return
  344. }
  345. cur.Fields = append(cur.Fields, field)
  346. f = append(f, cur)
  347. }
  348. i, j := 0, 1
  349. for j < len(f) {
  350. if f[i].AmqpType == f[j].AmqpType {
  351. f[i].Fields = append(f[i].Fields, f[j].Fields...)
  352. } else {
  353. i++
  354. f[i] = f[j]
  355. }
  356. j++
  357. }
  358. return f[:i+1], nil
  359. }
  360. return
  361. }
  362. func (me *renderer) HasType(typ string, method Method) bool {
  363. for _, f := range method.Fields {
  364. name, _ := me.FieldType(f)
  365. if name == typ {
  366. return true
  367. }
  368. }
  369. return false
  370. }
  371. func (me *renderer) HasField(field string, method Method) bool {
  372. for _, f := range method.Fields {
  373. name := me.FieldName(f)
  374. if name == field {
  375. return true
  376. }
  377. }
  378. return false
  379. }
  380. func (me *renderer) Domain(field Field) (domain Domain, err error) {
  381. for _, domain = range me.Root.Domains {
  382. if field.Domain == domain.Name {
  383. return
  384. }
  385. }
  386. return domain, nil
  387. //return domain, ErrUnknownDomain
  388. }
  389. func (me *renderer) FieldName(field Field) (t string) {
  390. t = public(field.Name)
  391. if field.Reserved {
  392. t = strings.ToLower(t)
  393. }
  394. return
  395. }
  396. func (me *renderer) FieldType(field Field) (t string, err error) {
  397. t = field.Type
  398. if t == "" {
  399. var domain Domain
  400. domain, err = me.Domain(field)
  401. if err != nil {
  402. return "", err
  403. }
  404. t = domain.Type
  405. }
  406. return
  407. }
  408. func (me *renderer) NativeType(amqpType string) (t string, err error) {
  409. if t, ok := amqpTypeToNative[amqpType]; ok {
  410. return t, nil
  411. }
  412. return "", ErrUnknownType
  413. }
  414. func (me *renderer) Tag(d Domain) string {
  415. label := "`"
  416. label += `domain:"` + d.Name + `"`
  417. if len(d.Type) > 0 {
  418. label += `,type:"` + d.Type + `"`
  419. }
  420. label += "`"
  421. return label
  422. }
  423. func (me *renderer) StructName(parts ...string) string {
  424. return parts[0] + public(parts[1:]...)
  425. }
  426. func clean(body string) (res string) {
  427. return strings.Replace(body, "\r", "", -1)
  428. }
  429. func private(parts ...string) string {
  430. return export(regexp.MustCompile(`[-_]\w`), parts...)
  431. }
  432. func public(parts ...string) string {
  433. return export(regexp.MustCompile(`^\w|[-_]\w`), parts...)
  434. }
  435. func export(delim *regexp.Regexp, parts ...string) (res string) {
  436. for _, in := range parts {
  437. res += delim.ReplaceAllStringFunc(in, func(match string) string {
  438. switch len(match) {
  439. case 1:
  440. return strings.ToUpper(match)
  441. case 2:
  442. return strings.ToUpper(match[1:])
  443. }
  444. panic("unreachable")
  445. })
  446. }
  447. return
  448. }
  449. func main() {
  450. var r renderer
  451. spec, err := ioutil.ReadAll(os.Stdin)
  452. if err != nil {
  453. log.Fatalln("Please pass spec on stdin", err)
  454. }
  455. err = xml.Unmarshal(spec, &r.Root)
  456. if err != nil {
  457. log.Fatalln("Could not parse XML:", err)
  458. }
  459. if err = packageTemplate.Execute(os.Stdout, &r); err != nil {
  460. log.Fatalln("Generate error: ", err)
  461. }
  462. }