msg_generate.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. //+build ignore

// msg_generate.go is meant to run with go generate. It uses
// golang.org/x/tools/go/packages and go/types to find all the RR struct types,
// then generates pack/unpack methods for each type based on its struct tags.
// The generated source is written to zmsg.go and is meant to be checked into git.
  6. package main
  7. import (
  8. "bytes"
  9. "fmt"
  10. "go/format"
  11. "go/types"
  12. "log"
  13. "os"
  14. "strings"
  15. "golang.org/x/tools/go/packages"
  16. )
  17. var packageHdr = `
  18. // Code generated by "go run msg_generate.go"; DO NOT EDIT.
  19. package dns
  20. `
  21. // getTypeStruct will take a type and the package scope, and return the
  22. // (innermost) struct if the type is considered a RR type (currently defined as
  23. // those structs beginning with a RR_Header, could be redefined as implementing
  24. // the RR interface). The bool return value indicates if embedded structs were
  25. // resolved.
  26. func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  27. st, ok := t.Underlying().(*types.Struct)
  28. if !ok {
  29. return nil, false
  30. }
  31. if st.NumFields() == 0 {
  32. return nil, false
  33. }
  34. if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
  35. return st, false
  36. }
  37. if st.Field(0).Anonymous() {
  38. st, _ := getTypeStruct(st.Field(0).Type(), scope)
  39. return st, true
  40. }
  41. return nil, false
  42. }
  43. // loadModule retrieves package description for a given module.
  44. func loadModule(name string) (*types.Package, error) {
  45. conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo}
  46. pkgs, err := packages.Load(&conf, name)
  47. if err != nil {
  48. return nil, err
  49. }
  50. return pkgs[0].Types, nil
  51. }
  52. func main() {
  53. // Import and type-check the package
  54. pkg, err := loadModule("github.com/miekg/dns")
  55. fatalIfErr(err)
  56. scope := pkg.Scope()
  57. // Collect actual types (*X)
  58. var namedTypes []string
  59. for _, name := range scope.Names() {
  60. o := scope.Lookup(name)
  61. if o == nil || !o.Exported() {
  62. continue
  63. }
  64. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  65. continue
  66. }
  67. if name == "PrivateRR" {
  68. continue
  69. }
  70. // Check if corresponding TypeX exists
  71. if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
  72. log.Fatalf("Constant Type%s does not exist.", o.Name())
  73. }
  74. namedTypes = append(namedTypes, o.Name())
  75. }
  76. b := &bytes.Buffer{}
  77. b.WriteString(packageHdr)
  78. fmt.Fprint(b, "// pack*() functions\n\n")
  79. for _, name := range namedTypes {
  80. o := scope.Lookup(name)
  81. st, _ := getTypeStruct(o.Type(), scope)
  82. fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
  83. for i := 1; i < st.NumFields(); i++ {
  84. o := func(s string) {
  85. fmt.Fprintf(b, s, st.Field(i).Name())
  86. fmt.Fprint(b, `if err != nil {
  87. return off, err
  88. }
  89. `)
  90. }
  91. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  92. switch st.Tag(i) {
  93. case `dns:"-"`: // ignored
  94. case `dns:"txt"`:
  95. o("off, err = packStringTxt(rr.%s, msg, off)\n")
  96. case `dns:"opt"`:
  97. o("off, err = packDataOpt(rr.%s, msg, off)\n")
  98. case `dns:"nsec"`:
  99. o("off, err = packDataNsec(rr.%s, msg, off)\n")
  100. case `dns:"pairs"`:
  101. o("off, err = packDataSVCB(rr.%s, msg, off)\n")
  102. case `dns:"domain-name"`:
  103. o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
  104. case `dns:"apl"`:
  105. o("off, err = packDataApl(rr.%s, msg, off)\n")
  106. default:
  107. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  108. }
  109. continue
  110. }
  111. switch {
  112. case st.Tag(i) == `dns:"-"`: // ignored
  113. case st.Tag(i) == `dns:"cdomain-name"`:
  114. o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
  115. case st.Tag(i) == `dns:"domain-name"`:
  116. o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
  117. case st.Tag(i) == `dns:"a"`:
  118. o("off, err = packDataA(rr.%s, msg, off)\n")
  119. case st.Tag(i) == `dns:"aaaa"`:
  120. o("off, err = packDataAAAA(rr.%s, msg, off)\n")
  121. case st.Tag(i) == `dns:"uint48"`:
  122. o("off, err = packUint48(rr.%s, msg, off)\n")
  123. case st.Tag(i) == `dns:"txt"`:
  124. o("off, err = packString(rr.%s, msg, off)\n")
  125. case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
  126. fallthrough
  127. case st.Tag(i) == `dns:"base32"`:
  128. o("off, err = packStringBase32(rr.%s, msg, off)\n")
  129. case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
  130. fallthrough
  131. case st.Tag(i) == `dns:"base64"`:
  132. o("off, err = packStringBase64(rr.%s, msg, off)\n")
  133. case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
  134. // directly write instead of using o() so we get the error check in the correct place
  135. field := st.Field(i).Name()
  136. fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
  137. if rr.%s != "-" {
  138. off, err = packStringHex(rr.%s, msg, off)
  139. if err != nil {
  140. return off, err
  141. }
  142. }
  143. `, field, field)
  144. continue
  145. case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
  146. fallthrough
  147. case st.Tag(i) == `dns:"hex"`:
  148. o("off, err = packStringHex(rr.%s, msg, off)\n")
  149. case st.Tag(i) == `dns:"any"`:
  150. o("off, err = packStringAny(rr.%s, msg, off)\n")
  151. case st.Tag(i) == `dns:"octet"`:
  152. o("off, err = packStringOctet(rr.%s, msg, off)\n")
  153. case st.Tag(i) == "":
  154. switch st.Field(i).Type().(*types.Basic).Kind() {
  155. case types.Uint8:
  156. o("off, err = packUint8(rr.%s, msg, off)\n")
  157. case types.Uint16:
  158. o("off, err = packUint16(rr.%s, msg, off)\n")
  159. case types.Uint32:
  160. o("off, err = packUint32(rr.%s, msg, off)\n")
  161. case types.Uint64:
  162. o("off, err = packUint64(rr.%s, msg, off)\n")
  163. case types.String:
  164. o("off, err = packString(rr.%s, msg, off)\n")
  165. default:
  166. log.Fatalln(name, st.Field(i).Name())
  167. }
  168. default:
  169. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  170. }
  171. }
  172. fmt.Fprintln(b, "return off, nil }\n")
  173. }
  174. fmt.Fprint(b, "// unpack*() functions\n\n")
  175. for _, name := range namedTypes {
  176. o := scope.Lookup(name)
  177. st, _ := getTypeStruct(o.Type(), scope)
  178. fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
  179. fmt.Fprint(b, `rdStart := off
  180. _ = rdStart
  181. `)
  182. for i := 1; i < st.NumFields(); i++ {
  183. o := func(s string) {
  184. fmt.Fprintf(b, s, st.Field(i).Name())
  185. fmt.Fprint(b, `if err != nil {
  186. return off, err
  187. }
  188. `)
  189. }
  190. // size-* are special, because they reference a struct member we should use for the length.
  191. if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
  192. structMember := structMember(st.Tag(i))
  193. structTag := structTag(st.Tag(i))
  194. switch structTag {
  195. case "hex":
  196. fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  197. case "base32":
  198. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  199. case "base64":
  200. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  201. default:
  202. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  203. }
  204. fmt.Fprint(b, `if err != nil {
  205. return off, err
  206. }
  207. `)
  208. continue
  209. }
  210. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  211. switch st.Tag(i) {
  212. case `dns:"-"`: // ignored
  213. case `dns:"txt"`:
  214. o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
  215. case `dns:"opt"`:
  216. o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
  217. case `dns:"nsec"`:
  218. o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
  219. case `dns:"pairs"`:
  220. o("rr.%s, off, err = unpackDataSVCB(msg, off)\n")
  221. case `dns:"domain-name"`:
  222. o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  223. case `dns:"apl"`:
  224. o("rr.%s, off, err = unpackDataApl(msg, off)\n")
  225. default:
  226. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  227. }
  228. continue
  229. }
  230. switch st.Tag(i) {
  231. case `dns:"-"`: // ignored
  232. case `dns:"cdomain-name"`:
  233. fallthrough
  234. case `dns:"domain-name"`:
  235. o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
  236. case `dns:"a"`:
  237. o("rr.%s, off, err = unpackDataA(msg, off)\n")
  238. case `dns:"aaaa"`:
  239. o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
  240. case `dns:"uint48"`:
  241. o("rr.%s, off, err = unpackUint48(msg, off)\n")
  242. case `dns:"txt"`:
  243. o("rr.%s, off, err = unpackString(msg, off)\n")
  244. case `dns:"base32"`:
  245. o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  246. case `dns:"base64"`:
  247. o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  248. case `dns:"hex"`:
  249. o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  250. case `dns:"any"`:
  251. o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  252. case `dns:"octet"`:
  253. o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
  254. case "":
  255. switch st.Field(i).Type().(*types.Basic).Kind() {
  256. case types.Uint8:
  257. o("rr.%s, off, err = unpackUint8(msg, off)\n")
  258. case types.Uint16:
  259. o("rr.%s, off, err = unpackUint16(msg, off)\n")
  260. case types.Uint32:
  261. o("rr.%s, off, err = unpackUint32(msg, off)\n")
  262. case types.Uint64:
  263. o("rr.%s, off, err = unpackUint64(msg, off)\n")
  264. case types.String:
  265. o("rr.%s, off, err = unpackString(msg, off)\n")
  266. default:
  267. log.Fatalln(name, st.Field(i).Name())
  268. }
  269. default:
  270. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  271. }
  272. // If we've hit len(msg) we return without error.
  273. if i < st.NumFields()-1 {
  274. fmt.Fprintf(b, `if off == len(msg) {
  275. return off, nil
  276. }
  277. `)
  278. }
  279. }
  280. fmt.Fprintf(b, "return off, nil }\n\n")
  281. }
  282. // gofmt
  283. res, err := format.Source(b.Bytes())
  284. if err != nil {
  285. b.WriteTo(os.Stderr)
  286. log.Fatal(err)
  287. }
  288. // write result
  289. f, err := os.Create("zmsg.go")
  290. fatalIfErr(err)
  291. defer f.Close()
  292. f.Write(res)
  293. }
  294. // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
  295. func structMember(s string) string {
  296. fields := strings.Split(s, ":")
  297. if len(fields) == 0 {
  298. return ""
  299. }
  300. f := fields[len(fields)-1]
  301. // f should have a closing "
  302. if len(f) > 1 {
  303. return f[:len(f)-1]
  304. }
  305. return f
  306. }
  307. // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
  308. func structTag(s string) string {
  309. fields := strings.Split(s, ":")
  310. if len(fields) < 2 {
  311. return ""
  312. }
  313. return fields[1][len("\"size-"):]
  314. }
  315. func fatalIfErr(err error) {
  316. if err != nil {
  317. log.Fatal(err)
  318. }
  319. }