duplicate_generate.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. //+build ignore
// duplicate_generate.go is meant to run with go generate. It will use
// golang.org/x/tools/go/packages and go/types to track down all the RR struct
// types. Then for each type it will generate an isDuplicate method that
// compares two records field by field, based on the struct tags. The generated
// source is written to zduplicate.go, and is meant to be checked into git.
  7. package main
  8. import (
  9. "bytes"
  10. "fmt"
  11. "go/format"
  12. "go/types"
  13. "log"
  14. "os"
  15. "golang.org/x/tools/go/packages"
  16. )
  17. var packageHdr = `
  18. // Code generated by "go run duplicate_generate.go"; DO NOT EDIT.
  19. package dns
  20. `
  21. func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  22. st, ok := t.Underlying().(*types.Struct)
  23. if !ok {
  24. return nil, false
  25. }
  26. if st.NumFields() == 0 {
  27. return nil, false
  28. }
  29. if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
  30. return st, false
  31. }
  32. if st.Field(0).Anonymous() {
  33. st, _ := getTypeStruct(st.Field(0).Type(), scope)
  34. return st, true
  35. }
  36. return nil, false
  37. }
  38. // loadModule retrieves package description for a given module.
  39. func loadModule(name string) (*types.Package, error) {
  40. conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo}
  41. pkgs, err := packages.Load(&conf, name)
  42. if err != nil {
  43. return nil, err
  44. }
  45. return pkgs[0].Types, nil
  46. }
  47. func main() {
  48. // Import and type-check the package
  49. pkg, err := loadModule("github.com/miekg/dns")
  50. fatalIfErr(err)
  51. scope := pkg.Scope()
  52. // Collect actual types (*X)
  53. var namedTypes []string
  54. for _, name := range scope.Names() {
  55. o := scope.Lookup(name)
  56. if o == nil || !o.Exported() {
  57. continue
  58. }
  59. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  60. continue
  61. }
  62. if name == "PrivateRR" || name == "OPT" {
  63. continue
  64. }
  65. namedTypes = append(namedTypes, o.Name())
  66. }
  67. b := &bytes.Buffer{}
  68. b.WriteString(packageHdr)
  69. // Generate the duplicate check for each type.
  70. fmt.Fprint(b, "// isDuplicate() functions\n\n")
  71. for _, name := range namedTypes {
  72. o := scope.Lookup(name)
  73. st, _ := getTypeStruct(o.Type(), scope)
  74. fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name)
  75. fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name)
  76. fmt.Fprint(b, "if !ok { return false }\n")
  77. fmt.Fprint(b, "_ = r2\n")
  78. for i := 1; i < st.NumFields(); i++ {
  79. field := st.Field(i).Name()
  80. o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) }
  81. o3 := func(s string) { fmt.Fprintf(b, s+"\n", field, field, field) }
  82. // For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are
  83. // *indirectly* defined as a slice in the net package).
  84. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  85. o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}")
  86. if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` {
  87. o3(`for i := 0; i < len(r1.%s); i++ {
  88. if !isDuplicateName(r1.%s[i], r2.%s[i]) {
  89. return false
  90. }
  91. }`)
  92. continue
  93. }
  94. if st.Tag(i) == `dns:"apl"` {
  95. o3(`for i := 0; i < len(r1.%s); i++ {
  96. if !r1.%s[i].equals(&r2.%s[i]) {
  97. return false
  98. }
  99. }`)
  100. continue
  101. }
  102. if st.Tag(i) == `dns:"pairs"` {
  103. o2(`if !areSVCBPairArraysEqual(r1.%s, r2.%s) {
  104. return false
  105. }`)
  106. continue
  107. }
  108. o3(`for i := 0; i < len(r1.%s); i++ {
  109. if r1.%s[i] != r2.%s[i] {
  110. return false
  111. }
  112. }`)
  113. continue
  114. }
  115. switch st.Tag(i) {
  116. case `dns:"-"`:
  117. // ignored
  118. case `dns:"a"`, `dns:"aaaa"`:
  119. o2("if !r1.%s.Equal(r2.%s) {\nreturn false\n}")
  120. case `dns:"cdomain-name"`, `dns:"domain-name"`:
  121. o2("if !isDuplicateName(r1.%s, r2.%s) {\nreturn false\n}")
  122. default:
  123. o2("if r1.%s != r2.%s {\nreturn false\n}")
  124. }
  125. }
  126. fmt.Fprintf(b, "return true\n}\n\n")
  127. }
  128. // gofmt
  129. res, err := format.Source(b.Bytes())
  130. if err != nil {
  131. b.WriteTo(os.Stderr)
  132. log.Fatal(err)
  133. }
  134. // write result
  135. f, err := os.Create("zduplicate.go")
  136. fatalIfErr(err)
  137. defer f.Close()
  138. f.Write(res)
  139. }
  140. func fatalIfErr(err error) {
  141. if err != nil {
  142. log.Fatal(err)
  143. }
  144. }