xfr.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. package dns
  2. import (
  3. "fmt"
  4. "time"
  5. )
  6. // Envelope is used when doing a zone transfer with a remote server.
  7. type Envelope struct {
  8. RR []RR // The set of RRs in the answer section of the xfr reply message.
  9. Error error // If something went wrong, this contains the error.
  10. }
  11. // A Transfer defines parameters that are used during a zone transfer.
  12. type Transfer struct {
  13. *Conn
  14. DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
  15. ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
  16. WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
  17. TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
  18. TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
  19. tsigTimersOnly bool
  20. }
  21. func (t *Transfer) tsigProvider() TsigProvider {
  22. if t.TsigProvider != nil {
  23. return t.TsigProvider
  24. }
  25. if t.TsigSecret != nil {
  26. return tsigSecretProvider(t.TsigSecret)
  27. }
  28. return nil
  29. }
  30. // TODO: Think we need to away to stop the transfer
  31. // In performs an incoming transfer with the server in a.
  32. // If you would like to set the source IP, or some other attribute
  33. // of a Dialer for a Transfer, you can do so by specifying the attributes
  34. // in the Transfer.Conn:
  35. //
  36. // d := net.Dialer{LocalAddr: transfer_source}
  37. // con, err := d.Dial("tcp", master)
  38. // dnscon := &dns.Conn{Conn:con}
  39. // transfer = &dns.Transfer{Conn: dnscon}
  40. // channel, err := transfer.In(message, master)
  41. func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
  42. switch q.Question[0].Qtype {
  43. case TypeAXFR, TypeIXFR:
  44. default:
  45. return nil, &Error{"unsupported question type"}
  46. }
  47. timeout := dnsTimeout
  48. if t.DialTimeout != 0 {
  49. timeout = t.DialTimeout
  50. }
  51. if t.Conn == nil {
  52. t.Conn, err = DialTimeout("tcp", a, timeout)
  53. if err != nil {
  54. return nil, err
  55. }
  56. }
  57. if err := t.WriteMsg(q); err != nil {
  58. return nil, err
  59. }
  60. env = make(chan *Envelope)
  61. switch q.Question[0].Qtype {
  62. case TypeAXFR:
  63. go t.inAxfr(q, env)
  64. case TypeIXFR:
  65. go t.inIxfr(q, env)
  66. }
  67. return env, nil
  68. }
  69. func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
  70. first := true
  71. defer t.Close()
  72. defer close(c)
  73. timeout := dnsTimeout
  74. if t.ReadTimeout != 0 {
  75. timeout = t.ReadTimeout
  76. }
  77. for {
  78. t.Conn.SetReadDeadline(time.Now().Add(timeout))
  79. in, err := t.ReadMsg()
  80. if err != nil {
  81. c <- &Envelope{nil, err}
  82. return
  83. }
  84. if q.Id != in.Id {
  85. c <- &Envelope{in.Answer, ErrId}
  86. return
  87. }
  88. if first {
  89. if in.Rcode != RcodeSuccess {
  90. c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
  91. return
  92. }
  93. if !isSOAFirst(in) {
  94. c <- &Envelope{in.Answer, ErrSoa}
  95. return
  96. }
  97. first = !first
  98. // only one answer that is SOA, receive more
  99. if len(in.Answer) == 1 {
  100. t.tsigTimersOnly = true
  101. c <- &Envelope{in.Answer, nil}
  102. continue
  103. }
  104. }
  105. if !first {
  106. t.tsigTimersOnly = true // Subsequent envelopes use this.
  107. if isSOALast(in) {
  108. c <- &Envelope{in.Answer, nil}
  109. return
  110. }
  111. c <- &Envelope{in.Answer, nil}
  112. }
  113. }
  114. }
  115. func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
  116. var serial uint32 // The first serial seen is the current server serial
  117. axfr := true
  118. n := 0
  119. qser := q.Ns[0].(*SOA).Serial
  120. defer t.Close()
  121. defer close(c)
  122. timeout := dnsTimeout
  123. if t.ReadTimeout != 0 {
  124. timeout = t.ReadTimeout
  125. }
  126. for {
  127. t.SetReadDeadline(time.Now().Add(timeout))
  128. in, err := t.ReadMsg()
  129. if err != nil {
  130. c <- &Envelope{nil, err}
  131. return
  132. }
  133. if q.Id != in.Id {
  134. c <- &Envelope{in.Answer, ErrId}
  135. return
  136. }
  137. if in.Rcode != RcodeSuccess {
  138. c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
  139. return
  140. }
  141. if n == 0 {
  142. // Check if the returned answer is ok
  143. if !isSOAFirst(in) {
  144. c <- &Envelope{in.Answer, ErrSoa}
  145. return
  146. }
  147. // This serial is important
  148. serial = in.Answer[0].(*SOA).Serial
  149. // Check if there are no changes in zone
  150. if qser >= serial {
  151. c <- &Envelope{in.Answer, nil}
  152. return
  153. }
  154. }
  155. // Now we need to check each message for SOA records, to see what we need to do
  156. t.tsigTimersOnly = true
  157. for _, rr := range in.Answer {
  158. if v, ok := rr.(*SOA); ok {
  159. if v.Serial == serial {
  160. n++
  161. // quit if it's a full axfr or the the servers' SOA is repeated the third time
  162. if axfr && n == 2 || n == 3 {
  163. c <- &Envelope{in.Answer, nil}
  164. return
  165. }
  166. } else if axfr {
  167. // it's an ixfr
  168. axfr = false
  169. }
  170. }
  171. }
  172. c <- &Envelope{in.Answer, nil}
  173. }
  174. }
  175. // Out performs an outgoing transfer with the client connecting in w.
  176. // Basic use pattern:
  177. //
  178. // ch := make(chan *dns.Envelope)
  179. // tr := new(dns.Transfer)
  180. // var wg sync.WaitGroup
  181. // go func() {
  182. // tr.Out(w, r, ch)
  183. // wg.Done()
  184. // }()
  185. // ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
  186. // close(ch)
  187. // wg.Wait() // wait until everything is written out
  188. // w.Close() // close connection
  189. //
  190. // The server is responsible for sending the correct sequence of RRs through the channel ch.
  191. func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
  192. for x := range ch {
  193. r := new(Msg)
  194. // Compress?
  195. r.SetReply(q)
  196. r.Authoritative = true
  197. // assume it fits TODO(miek): fix
  198. r.Answer = append(r.Answer, x.RR...)
  199. if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil {
  200. r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix())
  201. }
  202. if err := w.WriteMsg(r); err != nil {
  203. return err
  204. }
  205. w.TsigTimersOnly(true)
  206. }
  207. return nil
  208. }
  209. // ReadMsg reads a message from the transfer connection t.
  210. func (t *Transfer) ReadMsg() (*Msg, error) {
  211. m := new(Msg)
  212. p := make([]byte, MaxMsgSize)
  213. n, err := t.Read(p)
  214. if err != nil && n == 0 {
  215. return nil, err
  216. }
  217. p = p[:n]
  218. if err := m.Unpack(p); err != nil {
  219. return nil, err
  220. }
  221. if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
  222. // Need to work on the original message p, as that was used to calculate the tsig.
  223. err = TsigVerifyWithProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
  224. t.tsigRequestMAC = ts.MAC
  225. }
  226. return m, err
  227. }
  228. // WriteMsg writes a message through the transfer connection t.
  229. func (t *Transfer) WriteMsg(m *Msg) (err error) {
  230. var out []byte
  231. if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
  232. out, t.tsigRequestMAC, err = TsigGenerateWithProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
  233. } else {
  234. out, err = m.Pack()
  235. }
  236. if err != nil {
  237. return err
  238. }
  239. _, err = t.Write(out)
  240. return err
  241. }
  242. func isSOAFirst(in *Msg) bool {
  243. return len(in.Answer) > 0 &&
  244. in.Answer[0].Header().Rrtype == TypeSOA
  245. }
  246. func isSOALast(in *Msg) bool {
  247. return len(in.Answer) > 0 &&
  248. in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
  249. }
  // errXFR is the format string for the error reported when a transfer reply
  // carries a non-success rcode. Its single %d verb receives the rcode.
  const (
  	errXFR = "bad xfr rcode: %d"
  )