msg_generate.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
//go:build ignore
// +build ignore

// msg_generate.go is meant to run with go generate. It will use
// golang.org/x/tools/go/packages and go/types to track down all the RR struct
// types. Then for each type it will generate pack/unpack methods based on the
// struct tags. The generated source is written to zmsg.go, and is meant to be
// checked into git.
  6. package main
  7. import (
  8. "bytes"
  9. "fmt"
  10. "go/format"
  11. "go/types"
  12. "log"
  13. "os"
  14. "strings"
  15. "golang.org/x/tools/go/packages"
  16. )
  17. var packageHdr = `
  18. // Code generated by "go run msg_generate.go"; DO NOT EDIT.
  19. package dns
  20. `
  21. // getTypeStruct will take a type and the package scope, and return the
  22. // (innermost) struct if the type is considered a RR type (currently defined as
  23. // those structs beginning with a RR_Header, could be redefined as implementing
  24. // the RR interface). The bool return value indicates if embedded structs were
  25. // resolved.
  26. func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  27. st, ok := t.Underlying().(*types.Struct)
  28. if !ok {
  29. return nil, false
  30. }
  31. if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
  32. return st, false
  33. }
  34. if st.Field(0).Anonymous() {
  35. st, _ := getTypeStruct(st.Field(0).Type(), scope)
  36. return st, true
  37. }
  38. return nil, false
  39. }
  40. // loadModule retrieves package description for a given module.
  41. func loadModule(name string) (*types.Package, error) {
  42. conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo}
  43. pkgs, err := packages.Load(&conf, name)
  44. if err != nil {
  45. return nil, err
  46. }
  47. return pkgs[0].Types, nil
  48. }
  49. func main() {
  50. // Import and type-check the package
  51. pkg, err := loadModule("github.com/miekg/dns")
  52. fatalIfErr(err)
  53. scope := pkg.Scope()
  54. // Collect actual types (*X)
  55. var namedTypes []string
  56. for _, name := range scope.Names() {
  57. o := scope.Lookup(name)
  58. if o == nil || !o.Exported() {
  59. continue
  60. }
  61. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  62. continue
  63. }
  64. if name == "PrivateRR" {
  65. continue
  66. }
  67. // Check if corresponding TypeX exists
  68. if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
  69. log.Fatalf("Constant Type%s does not exist.", o.Name())
  70. }
  71. namedTypes = append(namedTypes, o.Name())
  72. }
  73. b := &bytes.Buffer{}
  74. b.WriteString(packageHdr)
  75. fmt.Fprint(b, "// pack*() functions\n\n")
  76. for _, name := range namedTypes {
  77. o := scope.Lookup(name)
  78. st, _ := getTypeStruct(o.Type(), scope)
  79. fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
  80. for i := 1; i < st.NumFields(); i++ {
  81. o := func(s string) {
  82. fmt.Fprintf(b, s, st.Field(i).Name())
  83. fmt.Fprint(b, `if err != nil {
  84. return off, err
  85. }
  86. `)
  87. }
  88. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  89. switch st.Tag(i) {
  90. case `dns:"-"`: // ignored
  91. case `dns:"txt"`:
  92. o("off, err = packStringTxt(rr.%s, msg, off)\n")
  93. case `dns:"opt"`:
  94. o("off, err = packDataOpt(rr.%s, msg, off)\n")
  95. case `dns:"nsec"`:
  96. o("off, err = packDataNsec(rr.%s, msg, off)\n")
  97. case `dns:"domain-name"`:
  98. o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
  99. case `dns:"apl"`:
  100. o("off, err = packDataApl(rr.%s, msg, off)\n")
  101. default:
  102. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  103. }
  104. continue
  105. }
  106. switch {
  107. case st.Tag(i) == `dns:"-"`: // ignored
  108. case st.Tag(i) == `dns:"cdomain-name"`:
  109. o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
  110. case st.Tag(i) == `dns:"domain-name"`:
  111. o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
  112. case st.Tag(i) == `dns:"a"`:
  113. o("off, err = packDataA(rr.%s, msg, off)\n")
  114. case st.Tag(i) == `dns:"aaaa"`:
  115. o("off, err = packDataAAAA(rr.%s, msg, off)\n")
  116. case st.Tag(i) == `dns:"uint48"`:
  117. o("off, err = packUint48(rr.%s, msg, off)\n")
  118. case st.Tag(i) == `dns:"txt"`:
  119. o("off, err = packString(rr.%s, msg, off)\n")
  120. case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
  121. fallthrough
  122. case st.Tag(i) == `dns:"base32"`:
  123. o("off, err = packStringBase32(rr.%s, msg, off)\n")
  124. case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
  125. fallthrough
  126. case st.Tag(i) == `dns:"base64"`:
  127. o("off, err = packStringBase64(rr.%s, msg, off)\n")
  128. case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
  129. // directly write instead of using o() so we get the error check in the correct place
  130. field := st.Field(i).Name()
  131. fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
  132. if rr.%s != "-" {
  133. off, err = packStringHex(rr.%s, msg, off)
  134. if err != nil {
  135. return off, err
  136. }
  137. }
  138. `, field, field)
  139. continue
  140. case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
  141. fallthrough
  142. case st.Tag(i) == `dns:"hex"`:
  143. o("off, err = packStringHex(rr.%s, msg, off)\n")
  144. case st.Tag(i) == `dns:"any"`:
  145. o("off, err = packStringAny(rr.%s, msg, off)\n")
  146. case st.Tag(i) == `dns:"octet"`:
  147. o("off, err = packStringOctet(rr.%s, msg, off)\n")
  148. case st.Tag(i) == "":
  149. switch st.Field(i).Type().(*types.Basic).Kind() {
  150. case types.Uint8:
  151. o("off, err = packUint8(rr.%s, msg, off)\n")
  152. case types.Uint16:
  153. o("off, err = packUint16(rr.%s, msg, off)\n")
  154. case types.Uint32:
  155. o("off, err = packUint32(rr.%s, msg, off)\n")
  156. case types.Uint64:
  157. o("off, err = packUint64(rr.%s, msg, off)\n")
  158. case types.String:
  159. o("off, err = packString(rr.%s, msg, off)\n")
  160. default:
  161. log.Fatalln(name, st.Field(i).Name())
  162. }
  163. default:
  164. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  165. }
  166. }
  167. fmt.Fprintln(b, "return off, nil }\n")
  168. }
  169. fmt.Fprint(b, "// unpack*() functions\n\n")
  170. for _, name := range namedTypes {
  171. o := scope.Lookup(name)
  172. st, _ := getTypeStruct(o.Type(), scope)
  173. fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
  174. fmt.Fprint(b, `rdStart := off
  175. _ = rdStart
  176. `)
  177. for i := 1; i < st.NumFields(); i++ {
  178. o := func(s string) {
  179. fmt.Fprintf(b, s, st.Field(i).Name())
  180. fmt.Fprint(b, `if err != nil {
  181. return off, err
  182. }
  183. `)
  184. }
  185. // size-* are special, because they reference a struct member we should use for the length.
  186. if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
  187. structMember := structMember(st.Tag(i))
  188. structTag := structTag(st.Tag(i))
  189. switch structTag {
  190. case "hex":
  191. fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  192. case "base32":
  193. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  194. case "base64":
  195. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  196. default:
  197. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  198. }
  199. fmt.Fprint(b, `if err != nil {
  200. return off, err
  201. }
  202. `)
  203. continue
  204. }
  205. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  206. switch st.Tag(i) {
  207. case `dns:"-"`: // ignored
  208. case `dns:"txt"`:
  209. o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
  210. case `dns:"opt"`:
  211. o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
  212. case `dns:"nsec"`:
  213. o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
  214. case `dns:"domain-name"`:
  215. o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  216. case `dns:"apl"`:
  217. o("rr.%s, off, err = unpackDataApl(msg, off)\n")
  218. default:
  219. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  220. }
  221. continue
  222. }
  223. switch st.Tag(i) {
  224. case `dns:"-"`: // ignored
  225. case `dns:"cdomain-name"`:
  226. fallthrough
  227. case `dns:"domain-name"`:
  228. o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
  229. case `dns:"a"`:
  230. o("rr.%s, off, err = unpackDataA(msg, off)\n")
  231. case `dns:"aaaa"`:
  232. o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
  233. case `dns:"uint48"`:
  234. o("rr.%s, off, err = unpackUint48(msg, off)\n")
  235. case `dns:"txt"`:
  236. o("rr.%s, off, err = unpackString(msg, off)\n")
  237. case `dns:"base32"`:
  238. o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  239. case `dns:"base64"`:
  240. o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  241. case `dns:"hex"`:
  242. o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  243. case `dns:"any"`:
  244. o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  245. case `dns:"octet"`:
  246. o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
  247. case "":
  248. switch st.Field(i).Type().(*types.Basic).Kind() {
  249. case types.Uint8:
  250. o("rr.%s, off, err = unpackUint8(msg, off)\n")
  251. case types.Uint16:
  252. o("rr.%s, off, err = unpackUint16(msg, off)\n")
  253. case types.Uint32:
  254. o("rr.%s, off, err = unpackUint32(msg, off)\n")
  255. case types.Uint64:
  256. o("rr.%s, off, err = unpackUint64(msg, off)\n")
  257. case types.String:
  258. o("rr.%s, off, err = unpackString(msg, off)\n")
  259. default:
  260. log.Fatalln(name, st.Field(i).Name())
  261. }
  262. default:
  263. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  264. }
  265. // If we've hit len(msg) we return without error.
  266. if i < st.NumFields()-1 {
  267. fmt.Fprintf(b, `if off == len(msg) {
  268. return off, nil
  269. }
  270. `)
  271. }
  272. }
  273. fmt.Fprintf(b, "return off, nil }\n\n")
  274. }
  275. // gofmt
  276. res, err := format.Source(b.Bytes())
  277. if err != nil {
  278. b.WriteTo(os.Stderr)
  279. log.Fatal(err)
  280. }
  281. // write result
  282. f, err := os.Create("zmsg.go")
  283. fatalIfErr(err)
  284. defer f.Close()
  285. f.Write(res)
  286. }
  287. // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
  288. func structMember(s string) string {
  289. fields := strings.Split(s, ":")
  290. if len(fields) == 0 {
  291. return ""
  292. }
  293. f := fields[len(fields)-1]
  294. // f should have a closing "
  295. if len(f) > 1 {
  296. return f[:len(f)-1]
  297. }
  298. return f
  299. }
  300. // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
  301. func structTag(s string) string {
  302. fields := strings.Split(s, ":")
  303. if len(fields) < 2 {
  304. return ""
  305. }
  306. return fields[1][len("\"size-"):]
  307. }
  308. func fatalIfErr(err error) {
  309. if err != nil {
  310. log.Fatal(err)
  311. }
  312. }