msg_generate.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. //+build ignore
  2. // msg_generate.go is meant to run with go generate. It will use
  3. // go/{importer,types} to track down all the RR struct types. Then for each type
  4. // it will generate pack/unpack methods based on the struct tags. The generated source is
  5. // written to zmsg.go, and is meant to be checked into git.
  6. package main
  7. import (
  8. "bytes"
  9. "fmt"
  10. "go/format"
  11. "go/importer"
  12. "go/types"
  13. "log"
  14. "os"
  15. "strings"
  16. )
  17. var packageHdr = `
  18. // *** DO NOT MODIFY ***
  19. // AUTOGENERATED BY go generate from msg_generate.go
  20. package dns
  21. `
  22. // getTypeStruct will take a type and the package scope, and return the
  23. // (innermost) struct if the type is considered a RR type (currently defined as
  24. // those structs beginning with a RR_Header, could be redefined as implementing
  25. // the RR interface). The bool return value indicates if embedded structs were
  26. // resolved.
  27. func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  28. st, ok := t.Underlying().(*types.Struct)
  29. if !ok {
  30. return nil, false
  31. }
  32. if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
  33. return st, false
  34. }
  35. if st.Field(0).Anonymous() {
  36. st, _ := getTypeStruct(st.Field(0).Type(), scope)
  37. return st, true
  38. }
  39. return nil, false
  40. }
  41. func main() {
  42. // Import and type-check the package
  43. pkg, err := importer.Default().Import("github.com/miekg/dns")
  44. fatalIfErr(err)
  45. scope := pkg.Scope()
  46. // Collect actual types (*X)
  47. var namedTypes []string
  48. for _, name := range scope.Names() {
  49. o := scope.Lookup(name)
  50. if o == nil || !o.Exported() {
  51. continue
  52. }
  53. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  54. continue
  55. }
  56. if name == "PrivateRR" {
  57. continue
  58. }
  59. // Check if corresponding TypeX exists
  60. if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
  61. log.Fatalf("Constant Type%s does not exist.", o.Name())
  62. }
  63. namedTypes = append(namedTypes, o.Name())
  64. }
  65. b := &bytes.Buffer{}
  66. b.WriteString(packageHdr)
  67. fmt.Fprint(b, "// pack*() functions\n\n")
  68. for _, name := range namedTypes {
  69. o := scope.Lookup(name)
  70. st, _ := getTypeStruct(o.Type(), scope)
  71. fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
  72. fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
  73. if err != nil {
  74. return off, err
  75. }
  76. headerEnd := off
  77. `)
  78. for i := 1; i < st.NumFields(); i++ {
  79. o := func(s string) {
  80. fmt.Fprintf(b, s, st.Field(i).Name())
  81. fmt.Fprint(b, `if err != nil {
  82. return off, err
  83. }
  84. `)
  85. }
  86. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  87. switch st.Tag(i) {
  88. case `dns:"-"`: // ignored
  89. case `dns:"txt"`:
  90. o("off, err = packStringTxt(rr.%s, msg, off)\n")
  91. case `dns:"opt"`:
  92. o("off, err = packDataOpt(rr.%s, msg, off)\n")
  93. case `dns:"nsec"`:
  94. o("off, err = packDataNsec(rr.%s, msg, off)\n")
  95. case `dns:"domain-name"`:
  96. o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
  97. default:
  98. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  99. }
  100. continue
  101. }
  102. switch {
  103. case st.Tag(i) == `dns:"-"`: // ignored
  104. case st.Tag(i) == `dns:"cdomain-name"`:
  105. o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
  106. case st.Tag(i) == `dns:"domain-name"`:
  107. o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n")
  108. case st.Tag(i) == `dns:"a"`:
  109. o("off, err = packDataA(rr.%s, msg, off)\n")
  110. case st.Tag(i) == `dns:"aaaa"`:
  111. o("off, err = packDataAAAA(rr.%s, msg, off)\n")
  112. case st.Tag(i) == `dns:"uint48"`:
  113. o("off, err = packUint48(rr.%s, msg, off)\n")
  114. case st.Tag(i) == `dns:"txt"`:
  115. o("off, err = packString(rr.%s, msg, off)\n")
  116. case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
  117. fallthrough
  118. case st.Tag(i) == `dns:"base32"`:
  119. o("off, err = packStringBase32(rr.%s, msg, off)\n")
  120. case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
  121. fallthrough
  122. case st.Tag(i) == `dns:"base64"`:
  123. o("off, err = packStringBase64(rr.%s, msg, off)\n")
  124. case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
  125. // directly write instead of using o() so we get the error check in the correct place
  126. field := st.Field(i).Name()
  127. fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
  128. if rr.%s != "-" {
  129. off, err = packStringHex(rr.%s, msg, off)
  130. if err != nil {
  131. return off, err
  132. }
  133. }
  134. `, field, field)
  135. continue
  136. case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
  137. fallthrough
  138. case st.Tag(i) == `dns:"hex"`:
  139. o("off, err = packStringHex(rr.%s, msg, off)\n")
  140. case st.Tag(i) == `dns:"octet"`:
  141. o("off, err = packStringOctet(rr.%s, msg, off)\n")
  142. case st.Tag(i) == "":
  143. switch st.Field(i).Type().(*types.Basic).Kind() {
  144. case types.Uint8:
  145. o("off, err = packUint8(rr.%s, msg, off)\n")
  146. case types.Uint16:
  147. o("off, err = packUint16(rr.%s, msg, off)\n")
  148. case types.Uint32:
  149. o("off, err = packUint32(rr.%s, msg, off)\n")
  150. case types.Uint64:
  151. o("off, err = packUint64(rr.%s, msg, off)\n")
  152. case types.String:
  153. o("off, err = packString(rr.%s, msg, off)\n")
  154. default:
  155. log.Fatalln(name, st.Field(i).Name())
  156. }
  157. default:
  158. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  159. }
  160. }
  161. // We have packed everything, only now we know the rdlength of this RR
  162. fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
  163. fmt.Fprintln(b, "return off, nil }\n")
  164. }
  165. fmt.Fprint(b, "// unpack*() functions\n\n")
  166. for _, name := range namedTypes {
  167. o := scope.Lookup(name)
  168. st, _ := getTypeStruct(o.Type(), scope)
  169. fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
  170. fmt.Fprintf(b, "rr := new(%s)\n", name)
  171. fmt.Fprint(b, "rr.Hdr = h\n")
  172. fmt.Fprint(b, `if noRdata(h) {
  173. return rr, off, nil
  174. }
  175. var err error
  176. rdStart := off
  177. _ = rdStart
  178. `)
  179. for i := 1; i < st.NumFields(); i++ {
  180. o := func(s string) {
  181. fmt.Fprintf(b, s, st.Field(i).Name())
  182. fmt.Fprint(b, `if err != nil {
  183. return rr, off, err
  184. }
  185. `)
  186. }
  187. // size-* are special, because they reference a struct member we should use for the length.
  188. if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
  189. structMember := structMember(st.Tag(i))
  190. structTag := structTag(st.Tag(i))
  191. switch structTag {
  192. case "hex":
  193. fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  194. case "base32":
  195. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  196. case "base64":
  197. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  198. default:
  199. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  200. }
  201. fmt.Fprint(b, `if err != nil {
  202. return rr, off, err
  203. }
  204. `)
  205. continue
  206. }
  207. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  208. switch st.Tag(i) {
  209. case `dns:"-"`: // ignored
  210. case `dns:"txt"`:
  211. o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
  212. case `dns:"opt"`:
  213. o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
  214. case `dns:"nsec"`:
  215. o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
  216. case `dns:"domain-name"`:
  217. o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  218. default:
  219. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  220. }
  221. continue
  222. }
  223. switch st.Tag(i) {
  224. case `dns:"-"`: // ignored
  225. case `dns:"cdomain-name"`:
  226. fallthrough
  227. case `dns:"domain-name"`:
  228. o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
  229. case `dns:"a"`:
  230. o("rr.%s, off, err = unpackDataA(msg, off)\n")
  231. case `dns:"aaaa"`:
  232. o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
  233. case `dns:"uint48"`:
  234. o("rr.%s, off, err = unpackUint48(msg, off)\n")
  235. case `dns:"txt"`:
  236. o("rr.%s, off, err = unpackString(msg, off)\n")
  237. case `dns:"base32"`:
  238. o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  239. case `dns:"base64"`:
  240. o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  241. case `dns:"hex"`:
  242. o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  243. case `dns:"octet"`:
  244. o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
  245. case "":
  246. switch st.Field(i).Type().(*types.Basic).Kind() {
  247. case types.Uint8:
  248. o("rr.%s, off, err = unpackUint8(msg, off)\n")
  249. case types.Uint16:
  250. o("rr.%s, off, err = unpackUint16(msg, off)\n")
  251. case types.Uint32:
  252. o("rr.%s, off, err = unpackUint32(msg, off)\n")
  253. case types.Uint64:
  254. o("rr.%s, off, err = unpackUint64(msg, off)\n")
  255. case types.String:
  256. o("rr.%s, off, err = unpackString(msg, off)\n")
  257. default:
  258. log.Fatalln(name, st.Field(i).Name())
  259. }
  260. default:
  261. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  262. }
  263. // If we've hit len(msg) we return without error.
  264. if i < st.NumFields()-1 {
  265. fmt.Fprintf(b, `if off == len(msg) {
  266. return rr, off, nil
  267. }
  268. `)
  269. }
  270. }
  271. fmt.Fprintf(b, "return rr, off, err }\n\n")
  272. }
  273. // Generate typeToUnpack map
  274. fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
  275. for _, name := range namedTypes {
  276. if name == "RFC3597" {
  277. continue
  278. }
  279. fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
  280. }
  281. fmt.Fprintln(b, "}\n")
  282. // gofmt
  283. res, err := format.Source(b.Bytes())
  284. if err != nil {
  285. b.WriteTo(os.Stderr)
  286. log.Fatal(err)
  287. }
  288. // write result
  289. f, err := os.Create("zmsg.go")
  290. fatalIfErr(err)
  291. defer f.Close()
  292. f.Write(res)
  293. }
  294. // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
  295. func structMember(s string) string {
  296. fields := strings.Split(s, ":")
  297. if len(fields) == 0 {
  298. return ""
  299. }
  300. f := fields[len(fields)-1]
  301. // f should have a closing "
  302. if len(f) > 1 {
  303. return f[:len(f)-1]
  304. }
  305. return f
  306. }
  307. // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
  308. func structTag(s string) string {
  309. fields := strings.Split(s, ":")
  310. if len(fields) < 2 {
  311. return ""
  312. }
  313. return fields[1][len("\"size-"):]
  314. }
  315. func fatalIfErr(err error) {
  316. if err != nil {
  317. log.Fatal(err)
  318. }
  319. }