Browse Source

Merge branch 'master' into quic-enhancements

Rod Hynes 4 years ago
parent
commit
f5d39d5185
89 changed files with 4081 additions and 1012 deletions
  1. 134 0
      .github/workflows/tests.yml
  2. 0 77
      .travis.yml
  3. 1 1
      MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml
  4. 2 1
      README.md
  5. 9 0
      psiphon/common/osl/osl.go
  6. 49 0
      psiphon/common/parameters/parameters.go
  7. 5 0
      psiphon/common/parameters/parameters_test.go
  8. 41 0
      psiphon/common/parameters/portlist.go
  9. 196 0
      psiphon/common/portlist.go
  10. 222 0
      psiphon/common/portlist_test.go
  11. 82 12
      psiphon/common/protocol/serverEntry.go
  12. 18 4
      psiphon/config.go
  13. 34 25
      psiphon/controller.go
  14. BIN
      psiphon/controller_test.config.enc
  15. 6 1
      psiphon/dataStore.go
  16. 29 40
      psiphon/dialParameters.go
  17. 116 11
      psiphon/dialParameters_test.go
  18. BIN
      psiphon/feedback_test.config.enc
  19. 6 7
      psiphon/feedback_test.go
  20. 12 0
      psiphon/meekConn.go
  21. 3 0
      psiphon/net.go
  22. 16 8
      psiphon/notice.go
  23. 6 30
      psiphon/server/listener.go
  24. 2 36
      psiphon/server/listener_test.go
  25. 66 24
      psiphon/server/meek.go
  26. 75 7
      psiphon/server/meek_test.go
  27. 1 1
      psiphon/server/psinet/psinet.go
  28. 76 6
      psiphon/server/server_test.go
  29. 18 93
      psiphon/server/trafficRules.go
  30. 40 18
      psiphon/server/tunnelServer.go
  31. 105 54
      psiphon/server/udp.go
  32. 5 0
      psiphon/upstreamproxy/go-ntlm/README.md
  33. 23 5
      psiphon/upstreamproxy/go-ntlm/ntlm/av_pairs.go
  34. 37 5
      psiphon/upstreamproxy/go-ntlm/ntlm/challenge_responses.go
  35. 38 4
      psiphon/upstreamproxy/go-ntlm/ntlm/message_authenticate.go
  36. 23 1
      psiphon/upstreamproxy/go-ntlm/ntlm/message_challenge.go
  37. 6 1
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1.go
  38. 5 5
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1_test.go
  39. 6 2
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2.go
  40. 1 1
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2_test.go
  41. 14 0
      psiphon/upstreamproxy/go-ntlm/ntlm/payload.go
  42. 7 0
      psiphon/upstreamproxy/go-ntlm/ntlm/version.go
  43. 7 1
      vendor/github.com/Psiphon-Labs/tls-tris/key_agreement.go
  44. 1 1
      vendor/github.com/miekg/dns/Makefile.release
  45. 14 6
      vendor/github.com/miekg/dns/README.md
  46. 1 0
      vendor/github.com/miekg/dns/acceptfunc.go
  47. 110 45
      vendor/github.com/miekg/dns/client.go
  48. 1 4
      vendor/github.com/miekg/dns/defaults.go
  49. 27 3
      vendor/github.com/miekg/dns/dns.go
  50. 18 47
      vendor/github.com/miekg/dns/dnssec.go
  51. 3 4
      vendor/github.com/miekg/dns/dnssec_keygen.go
  52. 4 17
      vendor/github.com/miekg/dns/dnssec_keyscan.go
  53. 3 20
      vendor/github.com/miekg/dns/dnssec_privkey.go
  54. 29 5
      vendor/github.com/miekg/dns/doc.go
  55. 12 4
      vendor/github.com/miekg/dns/duplicate_generate.go
  56. 151 5
      vendor/github.com/miekg/dns/edns.go
  57. 12 12
      vendor/github.com/miekg/dns/generate.go
  58. 4 6
      vendor/github.com/miekg/dns/go.mod
  59. 9 38
      vendor/github.com/miekg/dns/go.sum
  60. 1 1
      vendor/github.com/miekg/dns/labels.go
  61. 0 0
      vendor/github.com/miekg/dns/listen_no_reuseport.go
  62. 0 0
      vendor/github.com/miekg/dns/listen_reuseport.go
  63. 15 13
      vendor/github.com/miekg/dns/msg.go
  64. 7 0
      vendor/github.com/miekg/dns/msg_generate.go
  65. 73 82
      vendor/github.com/miekg/dns/msg_helpers.go
  66. 11 5
      vendor/github.com/miekg/dns/msg_truncate.go
  67. 2 2
      vendor/github.com/miekg/dns/privaterr.go
  68. 59 22
      vendor/github.com/miekg/dns/scan.go
  69. 53 18
      vendor/github.com/miekg/dns/scan_rr.go
  70. 2 2
      vendor/github.com/miekg/dns/serve_mux.go
  71. 92 28
      vendor/github.com/miekg/dns/server.go
  72. 3 15
      vendor/github.com/miekg/dns/sig0.go
  73. 755 0
      vendor/github.com/miekg/dns/svcb.go
  74. 99 59
      vendor/github.com/miekg/dns/tsig.go
  75. 79 50
      vendor/github.com/miekg/dns/types.go
  76. 21 5
      vendor/github.com/miekg/dns/types_generate.go
  77. 3 1
      vendor/github.com/miekg/dns/update.go
  78. 1 1
      vendor/github.com/miekg/dns/version.go
  79. 183 0
      vendor/github.com/miekg/dns/zduplicate.go
  80. 134 0
      vendor/github.com/miekg/dns/zmsg.go
  81. 56 2
      vendor/github.com/miekg/dns/ztypes.go
  82. 20 0
      vendor/github.com/refraction-networking/utls/README.md
  83. 5 1
      vendor/github.com/refraction-networking/utls/key_agreement.go
  84. 16 1
      vendor/github.com/refraction-networking/utls/ticket.go
  85. 15 0
      vendor/github.com/refraction-networking/utls/u_common.go
  86. 360 0
      vendor/github.com/refraction-networking/utls/u_fingerprinter.go
  87. 10 0
      vendor/github.com/refraction-networking/utls/u_parrots.go
  88. 59 0
      vendor/github.com/refraction-networking/utls/u_public.go
  89. 6 6
      vendor/vendor.json

+ 134 - 0
.github/workflows/tests.yml

@@ -0,0 +1,134 @@
+name: CI
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - master
+      - staging-client
+      - staging-server
+
+jobs:
+  run_tests:
+
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ "ubuntu" ]
+        go: [ "1.14.12" ]
+        test-type: [ "detector", "coverage", "memory" ]
+
+    runs-on: ${{ matrix.os }}-latest
+
+    name: psiphon-tunnel-core ${{ matrix.test-type }} tests on ${{ matrix.os}}, Go ${{ matrix.go }}
+
+    permissions:
+      checks: write
+      contents: read
+
+    env:
+      GOPATH: ${{ github.workspace }}/go
+
+    steps:
+
+      - name: Clone repository
+        uses: actions/checkout@v2
+        with:
+          path: ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+
+      - name: Install Go
+        uses: actions/setup-go@v2
+        with:
+          go-version: ${{ matrix.go }}
+
+      - name: Install networking components
+        run: |
+          sudo apt-get update
+          sudo apt-get install libnetfilter-queue-dev
+          sudo apt-get install conntrack
+
+      - name: Install coverage tools
+        if: ${{ matrix.test-type == 'coverage' }}
+        run: |
+          go get github.com/axw/gocov/gocov
+          go get github.com/modocache/gover
+          go get github.com/mattn/goveralls
+          go get golang.org/x/tools/cmd/cover
+
+      - name: Check environment
+        run: |
+          echo "GitHub workspace: $GITHUB_WORKSPACE"
+          echo "Working directory: `pwd`"
+          echo "GOROOT: $GOROOT"
+          echo "GOPATH: $GOPATH"
+          echo "Go version: `go version`"
+
+      - name: Pave config files
+        env:
+          CONTROLLER_TEST_CONFIG: ${{ secrets.CONTROLLER_TEST_CONFIG }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          echo "$CONTROLLER_TEST_CONFIG" > ./psiphon/controller_test.config
+
+      # TODO: fix and re-enable test
+      # sudo -E env "PATH=$PATH" go test -v -race ./psiphon/common/tun
+      - name: Run tests with data race detector
+        if: ${{ matrix.test-type == 'detector' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v -race ./psiphon/common
+          go test -v -race ./psiphon/common/accesscontrol
+          go test -v -race ./psiphon/common/crypto/ssh
+          go test -v -race ./psiphon/common/fragmentor
+          go test -v -race ./psiphon/common/obfuscator
+          go test -v -race ./psiphon/common/osl
+          sudo -E env "PATH=$PATH" go test -v -race -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/common/packetman
+          go test -v -race ./psiphon/common/parameters
+          go test -v -race ./psiphon/common/protocol
+          go test -v -race ./psiphon/common/quic
+          go test -v -race ./psiphon/common/tactics
+          go test -v -race ./psiphon/common/values
+          go test -v -race ./psiphon/common/wildcard
+          go test -v -race ./psiphon/transferstats
+          sudo -E env "PATH=$PATH" go test -v -race -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/server
+          go test -v -race ./psiphon/server/psinet
+          go test -v -race ./psiphon
+          go test -v -race ./ClientLibrary/clientlib
+          go test -v -race ./Server/logging/analysis
+
+      # TODO: fix and re-enable test
+      # sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./psiphon/common/tun
+      - name: Run tests with coverage
+        env:
+          COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        if: ${{ matrix.test-type == 'coverage' && github.repository == 'Psiphon-Labs/psiphon-tunnel-core' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v -covermode=count -coverprofile=common.coverprofile ./psiphon/common
+          go test -v -covermode=count -coverprofile=accesscontrol.coverprofile ./psiphon/common/accesscontrol
+          go test -v -covermode=count -coverprofile=ssh.coverprofile ./psiphon/common/crypto/ssh
+          go test -v -covermode=count -coverprofile=fragmentor.coverprofile ./psiphon/common/fragmentor
+          go test -v -covermode=count -coverprofile=obfuscator.coverprofile ./psiphon/common/obfuscator
+          go test -v -covermode=count -coverprofile=osl.coverprofile ./psiphon/common/osl
+          sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=packetman.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/common/packetman
+          go test -v -covermode=count -coverprofile=parameters.coverprofile ./psiphon/common/parameters
+          go test -v -covermode=count -coverprofile=protocol.coverprofile ./psiphon/common/protocol
+          go test -v -covermode=count -coverprofile=quic.coverprofile ./psiphon/common/quic
+          go test -v -covermode=count -coverprofile=tactics.coverprofile ./psiphon/common/tactics
+          go test -v -covermode=count -coverprofile=values.coverprofile ./psiphon/common/values
+          go test -v -covermode=count -coverprofile=wildcard.coverprofile ./psiphon/common/wildcard
+          go test -v -covermode=count -coverprofile=transferstats.coverprofile ./psiphon/transferstats
+          sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=server.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/server
+          go test -v -covermode=count -coverprofile=psinet.coverprofile ./psiphon/server/psinet
+          go test -v -covermode=count -coverprofile=psiphon.coverprofile ./psiphon
+          go test -v -covermode=count -coverprofile=clientlib.coverprofile ./ClientLibrary/clientlib
+          go test -v -covermode=count -coverprofile=analysis.coverprofile ./Server/logging/analysis
+          $GOPATH/bin/gover
+          $GOPATH/bin/goveralls -coverprofile=gover.coverprofile -service=github -repotoken "$COVERALLS_TOKEN"
+
+      - name: Run memory tests
+        if: ${{ matrix.test-type == 'memory' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v ./psiphon/memory_test -run TestReconnectTunnel
+          go test -v ./psiphon/memory_test -run TestRestartController

+ 0 - 77
.travis.yml

@@ -1,77 +0,0 @@
-dist: trusty
-language: go
-sudo: required
-go:
-- 1.14.12
-addons:
-  apt_packages:
-    - libx11-dev
-    - libgles2-mesa-dev
-script:
-- cd psiphon
-- go test -race -v ./common
-- go test -race -v ./common/accesscontrol
-- go test -race -v ./common/crypto/ssh
-- go test -race -v ./common/fragmentor
-- go test -race -v ./common/obfuscator
-- go test -race -v ./common/osl
-- sudo -E env "PATH=$PATH" go test -race -v -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./common/packetman
-- go test -race -v ./common/parameters
-- go test -race -v ./common/protocol
-- go test -race -v ./common/quic
-- go test -race -v ./common/tactics
-# TODO: fix and reenable test, which is failing in TravisCI environment:
-# --- FAIL: TestTunneledTCPIPv4
-#    tun_test.go:226: startTestTCPClient failed: syscall.Connect failed: connection timed out
-#
-#- sudo -E env "PATH=$PATH" go test -race -v ./common/tun
-- go test -race -v ./common/values
-- go test -race -v ./common/wildcard
-- go test -race -v ./transferstats
-- sudo -E env "PATH=$PATH" go test -race -v -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./server
-- go test -race -v ./server/psinet
-- go test -race -v ../Server/logging/analysis
-- go test -race -v ../ClientLibrary/clientlib
-- go test -race -v
-- go test -v -covermode=count -coverprofile=common.coverprofile ./common
-- go test -v -covermode=count -coverprofile=accesscontrol.coverprofile ./common/accesscontrol
-- go test -v -covermode=count -coverprofile=ssh.coverprofile ./common/crypto/ssh
-- go test -v -covermode=count -coverprofile=fragmentor.coverprofile ./common/fragmentor
-- go test -v -covermode=count -coverprofile=obfuscator.coverprofile ./common/obfuscator
-- go test -v -covermode=count -coverprofile=osl.coverprofile ./common/osl
-- go test -v -covermode=count -coverprofile=parameters.coverprofile ./common/parameters
-- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=packetman.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./common/packetman
-- go test -v -covermode=count -coverprofile=protocol.coverprofile ./common/protocol
-- go test -v -covermode=count -coverprofile=quic.coverprofile ./common/quic
-- go test -v -covermode=count -coverprofile=tactics.coverprofile ./common/tactics
-# TODO: see "tun" test comment above
-#- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./common/tun
-- go test -v -covermode=count -coverprofile=values.coverprofile ./common/values
-- go test -v -covermode=count -coverprofile=wildcard.coverprofile ./common/wildcard
-- go test -v -covermode=count -coverprofile=transferstats.coverprofile ./transferstats
-- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=server.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./server
-- go test -v -covermode=count -coverprofile=psinet.coverprofile ./server/psinet
-- go test -v -covermode=count -coverprofile=analysis.coverprofile ../Server/logging/analysis
-- go test -v -covermode=count -coverprofile=clientlib.coverprofile ../ClientLibrary/clientlib
-- go test -v -covermode=count -coverprofile=psiphon.coverprofile
-- go test -v ./memory_test -run TestReconnectTunnel
-- go test -v ./memory_test -run TestRestartController
-after_script:
-- $HOME/gopath/bin/gover
-- $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci -repotoken $COVERALLS_TOKEN
-before_install:
-- go get github.com/axw/gocov/gocov
-- go get github.com/modocache/gover
-- go get github.com/mattn/goveralls
-- if ! go get github.com/golang/tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
-- git rev-parse --short HEAD > psiphon/git_rev
-- openssl aes-256-cbc -K $encrypted_bf83b4ab4874_key -iv $encrypted_bf83b4ab4874_iv
-  -in psiphon/controller_test.config.enc -out psiphon/controller_test.config -d
-- openssl aes-256-cbc -K $encrypted_560fd0d04977_key -iv $encrypted_560fd0d04977_iv 
-  -in psiphon/feedback_test.config.enc -out psiphon/feedback_test.config -d
-notifications:
-  slack:
-    rooms:
-      secure: jVo/BZ1iFtg4g5V+eNxETwXPnbhwVwGzN1vkHJnCLAhV/md3/uHGsZQIMfitqgrX/T+9JBVRbRezjBwfJHYLs40IJTCWt167Lz8R1NlazLyEpcGcdesG05cTl9oEcBb7X52kZt7r8ZIBwdB7W6U/E0/i41qKamiEJqISMsdOoFA=
-    on_success: always
-    on_failure: always

+ 1 - 1
MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml

@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="utf-8"?>
 <?xml version="1.0" encoding="utf-8"?>
 <full-backup-content>
 <full-backup-content>
-    <include domain="file" path="ca.psiphon.PsiphonTunnel.tunnel-core" />
+    <exclude domain="file" path="ca.psiphon.PsiphonTunnel.tunnel-core" />
 </full-backup-content>
 </full-backup-content>

+ 2 - 1
README.md

@@ -1,4 +1,5 @@
-[![Build Status](https://travis-ci.org/Psiphon-Labs/psiphon-tunnel-core.png)](https://travis-ci.org/Psiphon-Labs/psiphon-tunnel-core) [![Coverage Status](https://coveralls.io/repos/github/Psiphon-Labs/psiphon-tunnel-core/badge.svg?branch=master)](https://coveralls.io/github/Psiphon-Labs/psiphon-tunnel-core?branch=master)
+[![CI](https://github.com/Psiphon-Labs/psiphon-tunnel-core/actions/workflows/tests.yml/badge.svg)](https://github.com/Psiphon-Labs/psiphon-tunnel-core/actions/workflows/tests.yml) [![Coverage Status](https://coveralls.io/repos/github/Psiphon-Labs/psiphon-tunnel-core/badge.svg?branch=master)](https://coveralls.io/github/Psiphon-Labs/psiphon-tunnel-core?branch=master)
+
 
 
 Psiphon Tunnel Core README
 Psiphon Tunnel Core README
 ================================================================================
 ================================================================================

+ 9 - 0
psiphon/common/osl/osl.go

@@ -294,6 +294,10 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 
 
 	for _, scheme := range config.Schemes {
 	for _, scheme := range config.Schemes {
 
 
+		if scheme == nil {
+			return nil, errors.TraceNew("invalid scheme")
+		}
+
 		epoch, err := time.Parse(time.RFC3339, scheme.Epoch)
 		epoch, err := time.Parse(time.RFC3339, scheme.Epoch)
 		if err != nil {
 		if err != nil {
 			return nil, errors.Tracef("invalid epoch format: %s", err)
 			return nil, errors.Tracef("invalid epoch format: %s", err)
@@ -322,6 +326,11 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 		}
 
 
 		for index, seedSpec := range scheme.SeedSpecs {
 		for index, seedSpec := range scheme.SeedSpecs {
+
+			if seedSpec == nil {
+				return nil, errors.TraceNew("invalid seed spec")
+			}
+
 			if len(seedSpec.ID) != KEY_LENGTH_BYTES {
 			if len(seedSpec.ID) != KEY_LENGTH_BYTES {
 				return nil, errors.TraceNew("invalid seed spec ID")
 				return nil, errors.TraceNew("invalid seed spec ID")
 			}
 			}

+ 49 - 0
psiphon/common/parameters/parameters.go

@@ -99,6 +99,8 @@ const (
 	InitialLimitTunnelProtocolsCandidateCount        = "InitialLimitTunnelProtocolsCandidateCount"
 	InitialLimitTunnelProtocolsCandidateCount        = "InitialLimitTunnelProtocolsCandidateCount"
 	LimitTunnelProtocolsProbability                  = "LimitTunnelProtocolsProbability"
 	LimitTunnelProtocolsProbability                  = "LimitTunnelProtocolsProbability"
 	LimitTunnelProtocols                             = "LimitTunnelProtocols"
 	LimitTunnelProtocols                             = "LimitTunnelProtocols"
+	LimitTunnelDialPortNumbersProbability            = "LimitTunnelDialPortNumbersProbability"
+	LimitTunnelDialPortNumbers                       = "LimitTunnelDialPortNumbers"
 	LimitTLSProfilesProbability                      = "LimitTLSProfilesProbability"
 	LimitTLSProfilesProbability                      = "LimitTLSProfilesProbability"
 	LimitTLSProfiles                                 = "LimitTLSProfiles"
 	LimitTLSProfiles                                 = "LimitTLSProfiles"
 	UseOnlyCustomTLSProfiles                         = "UseOnlyCustomTLSProfiles"
 	UseOnlyCustomTLSProfiles                         = "UseOnlyCustomTLSProfiles"
@@ -362,6 +364,9 @@ var defaultParameters = map[string]struct {
 	LimitTunnelProtocolsProbability: {value: 1.0, minimum: 0.0},
 	LimitTunnelProtocolsProbability: {value: 1.0, minimum: 0.0},
 	LimitTunnelProtocols:            {value: protocol.TunnelProtocols{}},
 	LimitTunnelProtocols:            {value: protocol.TunnelProtocols{}},
 
 
+	LimitTunnelDialPortNumbersProbability: {value: 1.0, minimum: 0.0},
+	LimitTunnelDialPortNumbers:            {value: TunnelProtocolPortLists{}},
+
 	LimitTLSProfilesProbability:           {value: 1.0, minimum: 0.0},
 	LimitTLSProfilesProbability:           {value: 1.0, minimum: 0.0},
 	LimitTLSProfiles:                      {value: protocol.TLSProfiles{}},
 	LimitTLSProfiles:                      {value: protocol.TLSProfiles{}},
 	UseOnlyCustomTLSProfiles:              {value: false},
 	UseOnlyCustomTLSProfiles:              {value: false},
@@ -931,6 +936,22 @@ func (p *Parameters) Set(
 					}
 					}
 					return nil, errors.Trace(err)
 					return nil, errors.Trace(err)
 				}
 				}
+			case FrontingSpecs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
+			case TunnelProtocolPortLists:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 			}
 
 
 			// Enforce any minimums. Assumes defaultParameters[name]
 			// Enforce any minimums. Assumes defaultParameters[name]
@@ -1389,3 +1410,31 @@ func (p ParametersAccessor) FrontingSpecs(name string) FrontingSpecs {
 	p.snapshot.getValue(name, &value)
 	p.snapshot.getValue(name, &value)
 	return value
 	return value
 }
 }
+
+// TunnelProtocolPortLists returns a TunnelProtocolPortLists parameter value.
+func (p ParametersAccessor) TunnelProtocolPortLists(name string) TunnelProtocolPortLists {
+
+	probabilityName := name + "Probability"
+	_, ok := p.snapshot.parameters[probabilityName]
+	if ok {
+		probabilityValue := float64(1.0)
+		p.snapshot.getValue(probabilityName, &probabilityValue)
+		if !prng.FlipWeightedCoin(probabilityValue) {
+			defaultParameter, ok := defaultParameters[name]
+			if ok {
+				defaultValue, ok := defaultParameter.value.(TunnelProtocolPortLists)
+				if ok {
+					value := make(TunnelProtocolPortLists)
+					for tunnelProtocol, portLists := range defaultValue {
+						value[tunnelProtocol] = portLists
+					}
+					return value
+				}
+			}
+		}
+	}
+
+	value := make(TunnelProtocolPortLists)
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 5 - 0
psiphon/common/parameters/parameters_test.go

@@ -154,6 +154,11 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v)
 				t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v)
 			}
 			}
+		case TunnelProtocolPortLists:
+			g := p.Get().TunnelProtocolPortLists(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v)
+			}
 		default:
 		default:
 			t.Fatalf("Unhandled default type: %s", name)
 			t.Fatalf("Unhandled default type: %s", name)
 		}
 		}

+ 41 - 0
psiphon/common/parameters/portlist.go

@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// Validate checks that tunnel protocol names are valid.
+func (lists TunnelProtocolPortLists) Validate() error {
+	for tunnelProtocol, _ := range lists {
+		if tunnelProtocol != protocol.TUNNEL_PROTOCOLS_ALL &&
+			!common.Contains(protocol.SupportedTunnelProtocols, tunnelProtocol) {
+			return errors.TraceNew("invalid tunnel protocol for port list")
+		}
+	}
+	return nil
+}

+ 196 - 0
psiphon/common/portlist.go

@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"bytes"
+	"encoding/json"
+	"strconv"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// PortList provides a lookup for a configured list of IP ports and port
+// ranges. PortList is intended for use with JSON config files and is
+// initialized via UnmarshalJSON.
+//
+// A JSON port list field should look like:
+//
+// "FieldName": [1, 2, 3, [10, 20], [30, 40]]
+//
+// where the ports in the list are 1, 2, 3, 10-20, 30-40. UnmarshalJSON
+// validates that each port is in the range 1-65535 and that ranges have two
+// elements in increasing order. PortList is designed to be backwards
+// compatible with existing JSON config files where port list fields were
+// defined as `[]int`.
+type PortList struct {
+	portRanges [][2]int
+	lookup     map[int]bool
+}
+
+const lookupThreshold = 10
+
+// OptimizeLookups converts the internal port list representation to use a
+// map, which increases the performance of lookups for longer lists with an
+// increased memory footprint tradeoff. OptimizeLookups is not safe to use
+// concurrently with Lookup and should be called immediately after
+// UnmarshalJSON and before performing lookups.
+func (p *PortList) OptimizeLookups() {
+	if p == nil {
+		return
+	}
+	// TODO: does the threshold take long ranges into account?
+	if len(p.portRanges) > lookupThreshold {
+		p.lookup = make(map[int]bool)
+		for _, portRange := range p.portRanges {
+			for i := portRange[0]; i <= portRange[1]; i++ {
+				p.lookup[i] = true
+			}
+		}
+	}
+}
+
+// IsEmpty returns true for a nil PortList or a PortList with no entries.
+func (p *PortList) IsEmpty() bool {
+	if p == nil {
+		return true
+	}
+	return len(p.portRanges) == 0
+}
+
+// Lookup returns true if the specified port is in the port list and false
+// otherwise. Lookups on a nil PortList are allowed and return false.
+func (p *PortList) Lookup(port int) bool {
+	if p == nil {
+		return false
+	}
+	if p.lookup != nil {
+		return p.lookup[port]
+	}
+	for _, portRange := range p.portRanges {
+		if port >= portRange[0] && port <= portRange[1] {
+			return true
+		}
+	}
+	return false
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface.
+func (p *PortList) UnmarshalJSON(b []byte) error {
+
+	p.portRanges = nil
+	p.lookup = nil
+
+	if bytes.Equal(b, []byte("null")) {
+		return nil
+	}
+
+	decoder := json.NewDecoder(bytes.NewReader(b))
+	decoder.UseNumber()
+
+	var array []interface{}
+
+	err := decoder.Decode(&array)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	p.portRanges = make([][2]int, len(array))
+
+	for i, portRange := range array {
+
+		var startPort, endPort int64
+
+		if portNumber, ok := portRange.(json.Number); ok {
+
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+
+			startPort = port
+			endPort = port
+
+		} else if array, ok := portRange.([]interface{}); ok {
+
+			if len(array) != 2 {
+				return errors.TraceNew("invalid range size")
+			}
+
+			portNumber, ok := array[0].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			startPort = port
+
+			portNumber, ok = array[1].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err = portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			endPort = port
+
+		} else {
+
+			return errors.TraceNew("invalid type")
+		}
+
+		if startPort < 1 || startPort > 65535 {
+			return errors.TraceNew("invalid range start")
+		}
+
+		if endPort < 1 || endPort > 65535 || endPort < startPort {
+			return errors.TraceNew("invalid range end")
+		}
+
+		p.portRanges[i] = [2]int{int(startPort), int(endPort)}
+	}
+
+	return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface.
+func (p *PortList) MarshalJSON() ([]byte, error) {
+	var json bytes.Buffer
+	json.WriteString("[")
+	for i, portRange := range p.portRanges {
+		if i > 0 {
+			json.WriteString(",")
+		}
+		if portRange[0] == portRange[1] {
+			json.WriteString(strconv.Itoa(portRange[0]))
+		} else {
+			json.WriteString("[")
+			json.WriteString(strconv.Itoa(portRange[0]))
+			json.WriteString(",")
+			json.WriteString(strconv.Itoa(portRange[1]))
+			json.WriteString("]")
+		}
+	}
+	json.WriteString("]")
+	return json.Bytes(), nil
+}

+ 222 - 0
psiphon/common/portlist_test.go

@@ -0,0 +1,222 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+	"unicode"
+)
+
+func TestPortList(t *testing.T) {
+
+	var p *PortList
+
+	err := json.Unmarshal([]byte("[1.5]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of float port number")
+	}
+
+	err = json.Unmarshal([]byte("[-1]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of negative port number")
+	}
+
+	err = json.Unmarshal([]byte("[0]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[65536]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[[2,1]]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port range")
+	}
+
+	p = nil
+
+	if p.Lookup(1) != false {
+		t.Fatalf("unexpected nil PortList Lookup result")
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected nil PortList IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[1]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	s := struct {
+		List1 *PortList
+		List2 *PortList
+	}{}
+
+	jsonString := `
+    {
+        "List1" : [1,2,[10,20],100,[1000,2000]],
+        "List2" : [3,4,5,[300,400],1000,2000,[3000,3996],3997,3998,3999,4000]
+    }
+    `
+
+	err = json.Unmarshal([]byte(jsonString), &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	// Marshal and re-Unmarshal to exercise PortList.MarshalJSON.
+
+	jsonBytes, err := json.Marshal(s)
+	if err != nil {
+		t.Fatalf("Marshal failed: %v", err)
+	}
+
+	strip := func(s string) string {
+		return strings.Map(func(r rune) rune {
+			if unicode.IsSpace(r) {
+				return -1
+			}
+			return r
+		}, s)
+	}
+
+	if strip(jsonString) != strip(string(jsonBytes)) {
+
+		t.Fatalf("unexpected JSON encoding")
+	}
+
+	err = json.Unmarshal(jsonBytes, &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	s.List1.OptimizeLookups()
+	if s.List1.lookup != nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	s.List2.OptimizeLookups()
+	if s.List2.lookup == nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	for port := 0; port < 65536; port++ {
+
+		lookup1 := s.List1.Lookup(port)
+		expected1 := port == 1 ||
+			port == 2 ||
+			(port >= 10 && port <= 20) ||
+			port == 100 ||
+			(port >= 1000 && port <= 2000)
+		if lookup1 != expected1 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup1)
+		}
+
+		lookup2 := s.List2.Lookup(port)
+		expected2 := port == 3 ||
+			port == 4 ||
+			port == 5 ||
+			(port >= 300 && port <= 400) ||
+			port == 1000 || port == 2000 ||
+			(port >= 3000 && port <= 4000)
+		if lookup2 != expected2 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup2)
+		}
+	}
+}
+
+func BenchmarkPortListLinear(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,[10,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup != nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}
+
+func BenchmarkPortListMap(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,10,[11,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup == nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}

+ 82 - 12
psiphon/common/protocol/serverEntry.go

@@ -490,51 +490,121 @@ type ConditionallyEnabledComponents interface {
 	RefractionNetworkingEnabled() bool
 	RefractionNetworkingEnabled() bool
 }
 }
 
 
-// GetSupportedProtocols returns a list of tunnel protocols supported
-// by the ServerEntry's capabilities.
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// GetSupportedProtocols returns a list of tunnel protocols supported by the
+// ServerEntry's capabilities and allowed by various constraints.
 func (serverEntry *ServerEntry) GetSupportedProtocols(
 func (serverEntry *ServerEntry) GetSupportedProtocols(
 	conditionallyEnabled ConditionallyEnabledComponents,
 	conditionallyEnabled ConditionallyEnabledComponents,
 	useUpstreamProxy bool,
 	useUpstreamProxy bool,
 	limitTunnelProtocols []string,
 	limitTunnelProtocols []string,
+	limitTunnelDialPortNumbers TunnelProtocolPortLists,
 	excludeIntensive bool) []string {
 	excludeIntensive bool) []string {
 
 
 	supportedProtocols := make([]string, 0)
 	supportedProtocols := make([]string, 0)
 
 
-	for _, protocol := range SupportedTunnelProtocols {
+	for _, tunnelProtocol := range SupportedTunnelProtocols {
 
 
-		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(protocol) {
+		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(tunnelProtocol) {
 			continue
 			continue
 		}
 		}
 
 
 		if len(limitTunnelProtocols) > 0 {
 		if len(limitTunnelProtocols) > 0 {
-			if !common.Contains(limitTunnelProtocols, protocol) {
+			if !common.Contains(limitTunnelProtocols, tunnelProtocol) {
 				continue
 				continue
 			}
 			}
 		} else {
 		} else {
-			if common.Contains(DefaultDisabledTunnelProtocols, protocol) {
+			if common.Contains(DefaultDisabledTunnelProtocols, tunnelProtocol) {
 				continue
 				continue
 			}
 			}
 		}
 		}
 
 
-		if excludeIntensive && TunnelProtocolIsResourceIntensive(protocol) {
+		if excludeIntensive && TunnelProtocolIsResourceIntensive(tunnelProtocol) {
 			continue
 			continue
 		}
 		}
 
 
-		if (TunnelProtocolUsesQUIC(protocol) && !conditionallyEnabled.QUICEnabled()) ||
-			(TunnelProtocolUsesMarionette(protocol) && !conditionallyEnabled.MarionetteEnabled()) ||
-			(TunnelProtocolUsesRefractionNetworking(protocol) &&
+		if (TunnelProtocolUsesQUIC(tunnelProtocol) && !conditionallyEnabled.QUICEnabled()) ||
+			(TunnelProtocolUsesMarionette(tunnelProtocol) && !conditionallyEnabled.MarionetteEnabled()) ||
+			(TunnelProtocolUsesRefractionNetworking(tunnelProtocol) &&
 				!conditionallyEnabled.RefractionNetworkingEnabled()) {
 				!conditionallyEnabled.RefractionNetworkingEnabled()) {
 			continue
 			continue
 		}
 		}
 
 
-		if serverEntry.SupportsProtocol(protocol) {
-			supportedProtocols = append(supportedProtocols, protocol)
+		if !serverEntry.SupportsProtocol(tunnelProtocol) {
+			continue
+		}
+
+		dialPortNumber, err := serverEntry.GetDialPortNumber(tunnelProtocol)
+		if err != nil {
+			continue
+		}
+
+		if len(limitTunnelDialPortNumbers) > 0 {
+			if portList, ok := limitTunnelDialPortNumbers[tunnelProtocol]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			} else if portList, ok := limitTunnelDialPortNumbers[TUNNEL_PROTOCOLS_ALL]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			}
 		}
 		}
 
 
+		supportedProtocols = append(supportedProtocols, tunnelProtocol)
+
 	}
 	}
 	return supportedProtocols
 	return supportedProtocols
 }
 }
 
 
+func (serverEntry *ServerEntry) GetDialPortNumber(tunnelProtocol string) (int, error) {
+
+	if !serverEntry.SupportsProtocol(tunnelProtocol) {
+		return 0, errors.TraceNew("protocol not supported")
+	}
+
+	switch tunnelProtocol {
+
+	case TUNNEL_PROTOCOL_SSH:
+		return serverEntry.SshPort, nil
+
+	case TUNNEL_PROTOCOL_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedPort, nil
+
+	case TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedTapDancePort, nil
+
+	case TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedConjurePort, nil
+
+	case TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedQUICPort, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK,
+		TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
+		return 443, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
+		return 80, nil
+
+	case TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK:
+		return serverEntry.MeekServerPort, nil
+
+	case TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+		// The port is encoded in the marionnete "format"
+		// Limitations:
+		// - not compatible with LimitDialPortNumbers
+		// - accurate port is not reported via dial_port_number
+		return -1, nil
+	}
+
+	return 0, errors.TraceNew("unknown protocol")
+}
+
 // GetSupportedTacticsProtocols returns a list of tunnel protocols,
 // GetSupportedTacticsProtocols returns a list of tunnel protocols,
 // supported by the ServerEntry's capabilities, that may be used
 // supported by the ServerEntry's capabilities, that may be used
 // for tactics requests.
 // for tactics requests.

+ 18 - 4
psiphon/config.go

@@ -751,6 +751,9 @@ type Config struct {
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	UpstreamProxyAllowAllServerEntrySources *bool
 	UpstreamProxyAllowAllServerEntrySources *bool
 
 
+	// LimitTunnelDialPortNumbers is for testing purposes.
+	LimitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+
 	// params is the active parameters.Parameters with defaults, config values,
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
 	// and, optionally, tactics applied.
 	//
 	//
@@ -1604,7 +1607,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UseOnlyCustomTLSProfiles] = *config.UseOnlyCustomTLSProfiles
 		applyParameters[parameters.UseOnlyCustomTLSProfiles] = *config.UseOnlyCustomTLSProfiles
 	}
 	}
 
 
-	if config.CustomTLSProfiles != nil {
+	if len(config.CustomTLSProfiles) > 0 {
 		applyParameters[parameters.CustomTLSProfiles] = config.CustomTLSProfiles
 		applyParameters[parameters.CustomTLSProfiles] = config.CustomTLSProfiles
 	}
 	}
 
 
@@ -1616,7 +1619,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.NoDefaultTLSSessionIDProbability] = *config.NoDefaultTLSSessionIDProbability
 		applyParameters[parameters.NoDefaultTLSSessionIDProbability] = *config.NoDefaultTLSSessionIDProbability
 	}
 	}
 
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		applyParameters[parameters.DisableFrontingProviderTLSProfiles] = config.DisableFrontingProviderTLSProfiles
 		applyParameters[parameters.DisableFrontingProviderTLSProfiles] = config.DisableFrontingProviderTLSProfiles
 	}
 	}
 
 
@@ -1660,7 +1663,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ConjureAPIRegistrarURL] = config.ConjureAPIRegistrarURL
 		applyParameters[parameters.ConjureAPIRegistrarURL] = config.ConjureAPIRegistrarURL
 	}
 	}
 
 
-	if config.ConjureAPIRegistrarFrontingSpecs != nil {
+	if len(config.ConjureAPIRegistrarFrontingSpecs) > 0 {
 		applyParameters[parameters.ConjureAPIRegistrarFrontingSpecs] = config.ConjureAPIRegistrarFrontingSpecs
 		applyParameters[parameters.ConjureAPIRegistrarFrontingSpecs] = config.ConjureAPIRegistrarFrontingSpecs
 	}
 	}
 
 
@@ -1720,6 +1723,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 	}
 	}
 
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		applyParameters[parameters.LimitTunnelDialPortNumbers] = config.LimitTunnelDialPortNumbers
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 	// update setDialParametersHash.
 
 
@@ -1929,7 +1936,7 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.NoDefaultTLSSessionIDProbability)
 		binary.Write(hash, binary.LittleEndian, *config.NoDefaultTLSSessionIDProbability)
 	}
 	}
 
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		hash.Write([]byte("DisableFrontingProviderTLSProfiles"))
 		hash.Write([]byte("DisableFrontingProviderTLSProfiles"))
 		encodedDisableFrontingProviderTLSProfiles, _ :=
 		encodedDisableFrontingProviderTLSProfiles, _ :=
 			json.Marshal(config.DisableFrontingProviderTLSProfiles)
 			json.Marshal(config.DisableFrontingProviderTLSProfiles)
@@ -2044,6 +2051,13 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)
 	}
 	}
 
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		hash.Write([]byte("LimitTunnelDialPortNumbers"))
+		encodedLimitTunnelDialPortNumbers, _ :=
+			json.Marshal(config.LimitTunnelDialPortNumbers)
+		hash.Write(encodedLimitTunnelDialPortNumbers)
+	}
+
 	config.dialParametersHash = hash.Sum(nil)
 	config.dialParametersHash = hash.Sum(nil)
 }
 }
 
 

+ 34 - 25
psiphon/controller.go

@@ -1358,15 +1358,16 @@ func (controller *Controller) triggerFetches() {
 }
 }
 
 
 type protocolSelectionConstraints struct {
 type protocolSelectionConstraints struct {
-	useUpstreamProxy                    bool
-	initialLimitProtocols               protocol.TunnelProtocols
-	initialLimitProtocolsCandidateCount int
-	limitProtocols                      protocol.TunnelProtocols
-	replayCandidateCount                int
+	useUpstreamProxy                          bool
+	initialLimitTunnelProtocols               protocol.TunnelProtocols
+	initialLimitTunnelProtocolsCandidateCount int
+	limitTunnelProtocols                      protocol.TunnelProtocols
+	limitTunnelDialPortNumbers                protocol.TunnelProtocolPortLists
+	replayCandidateCount                      int
 }
 }
 
 
 func (p *protocolSelectionConstraints) hasInitialProtocols() bool {
 func (p *protocolSelectionConstraints) hasInitialProtocols() bool {
-	return len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > 0
+	return len(p.initialLimitTunnelProtocols) > 0 && p.initialLimitTunnelProtocolsCandidateCount > 0
 }
 }
 
 
 func (p *protocolSelectionConstraints) isInitialCandidate(
 func (p *protocolSelectionConstraints) isInitialCandidate(
@@ -1377,7 +1378,8 @@ func (p *protocolSelectionConstraints) isInitialCandidate(
 		len(serverEntry.GetSupportedProtocols(
 		len(serverEntry.GetSupportedProtocols(
 			conditionallyEnabledComponents{},
 			conditionallyEnabledComponents{},
 			p.useUpstreamProxy,
 			p.useUpstreamProxy,
-			p.initialLimitProtocols,
+			p.initialLimitTunnelProtocols,
+			p.limitTunnelDialPortNumbers,
 			excludeIntensive)) > 0
 			excludeIntensive)) > 0
 }
 }
 
 
@@ -1385,12 +1387,12 @@ func (p *protocolSelectionConstraints) isCandidate(
 	excludeIntensive bool,
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) bool {
 	serverEntry *protocol.ServerEntry) bool {
 
 
-	return len(p.limitProtocols) == 0 ||
-		len(serverEntry.GetSupportedProtocols(
-			conditionallyEnabledComponents{},
-			p.useUpstreamProxy,
-			p.limitProtocols,
-			excludeIntensive)) > 0
+	return len(serverEntry.GetSupportedProtocols(
+		conditionallyEnabledComponents{},
+		p.useUpstreamProxy,
+		p.limitTunnelProtocols,
+		p.limitTunnelDialPortNumbers,
+		excludeIntensive)) > 0
 }
 }
 
 
 func (p *protocolSelectionConstraints) canReplay(
 func (p *protocolSelectionConstraints) canReplay(
@@ -1413,16 +1415,19 @@ func (p *protocolSelectionConstraints) supportedProtocols(
 	excludeIntensive bool,
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) []string {
 	serverEntry *protocol.ServerEntry) []string {
 
 
-	limitProtocols := p.limitProtocols
+	limitTunnelProtocols := p.limitTunnelProtocols
+
+	if len(p.initialLimitTunnelProtocols) > 0 &&
+		p.initialLimitTunnelProtocolsCandidateCount > connectTunnelCount {
 
 
-	if len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > connectTunnelCount {
-		limitProtocols = p.initialLimitProtocols
+		limitTunnelProtocols = p.initialLimitTunnelProtocols
 	}
 	}
 
 
 	return serverEntry.GetSupportedProtocols(
 	return serverEntry.GetSupportedProtocols(
 		conditionallyEnabledComponents{},
 		conditionallyEnabledComponents{},
 		p.useUpstreamProxy,
 		p.useUpstreamProxy,
-		limitProtocols,
+		limitTunnelProtocols,
+		p.limitTunnelDialPortNumbers,
 		excludeIntensive)
 		excludeIntensive)
 }
 }
 
 
@@ -1578,11 +1583,15 @@ func (controller *Controller) launchEstablishing() {
 	p := controller.config.GetParameters().Get()
 	p := controller.config.GetParameters().Get()
 
 
 	controller.protocolSelectionConstraints = &protocolSelectionConstraints{
 	controller.protocolSelectionConstraints = &protocolSelectionConstraints{
-		useUpstreamProxy:                    controller.config.UseUpstreamProxy(),
-		initialLimitProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
-		initialLimitProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
-		limitProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
-		replayCandidateCount:                p.Int(parameters.ReplayCandidateCount),
+		useUpstreamProxy:                          controller.config.UseUpstreamProxy(),
+		initialLimitTunnelProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
+		initialLimitTunnelProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
+		limitTunnelProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
+
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+
+		replayCandidateCount: p.Int(parameters.ReplayCandidateCount),
 	}
 	}
 
 
 	// ConnectionWorkerPoolSize may be set by tactics.
 	// ConnectionWorkerPoolSize may be set by tactics.
@@ -1626,7 +1635,7 @@ func (controller *Controller) launchEstablishing() {
 	// proceeding.
 	// proceeding.
 
 
 	awaitResponse := tunnelPoolSize > 1 ||
 	awaitResponse := tunnelPoolSize > 1 ||
-		controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0
+		controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0
 
 
 	// AvailableEgressRegions: after a fresh install, the outer client may not
 	// AvailableEgressRegions: after a fresh install, the outer client may not
 	// have a list of regions to display; and LimitTunnelProtocols may reduce the
 	// have a list of regions to display; and LimitTunnelProtocols may reduce the
@@ -1720,11 +1729,11 @@ func (controller *Controller) launchEstablishing() {
 		// protocols may have some bad effect, such as a firewall blocking all
 		// protocols may have some bad effect, such as a firewall blocking all
 		// traffic from a host.
 		// traffic from a host.
 
 
-		if controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0 {
+		if controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0 {
 
 
 			if reportResponse.initialCandidatesAnyEgressRegion == 0 {
 			if reportResponse.initialCandidatesAnyEgressRegion == 0 {
 				NoticeWarning("skipping initial limit tunnel protocols")
 				NoticeWarning("skipping initial limit tunnel protocols")
-				controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount = 0
+				controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount = 0
 
 
 				// Since we were unable to satisfy the InitialLimitTunnelProtocols
 				// Since we were unable to satisfy the InitialLimitTunnelProtocols
 				// tactic, trigger RSL, OSL, and upgrade fetches to potentially
 				// tactic, trigger RSL, OSL, and upgrade fetches to potentially

BIN
psiphon/controller_test.config.enc


+ 6 - 1
psiphon/dataStore.go

@@ -671,7 +671,11 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 			return false, nil, errors.TraceNew("TargetServerEntry does not support EgressRegion")
 			return false, nil, errors.TraceNew("TargetServerEntry does not support EgressRegion")
 		}
 		}
 
 
-		limitTunnelProtocols := config.GetParameters().Get().TunnelProtocols(parameters.LimitTunnelProtocols)
+		p := config.GetParameters().Get()
+		limitTunnelProtocols := p.TunnelProtocols(parameters.LimitTunnelProtocols)
+		limitTunnelDialPortNumbers := protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers))
+
 		if len(limitTunnelProtocols) > 0 {
 		if len(limitTunnelProtocols) > 0 {
 			// At the ServerEntryIterator level, only limitTunnelProtocols is applied;
 			// At the ServerEntryIterator level, only limitTunnelProtocols is applied;
 			// excludeIntensive is handled higher up.
 			// excludeIntensive is handled higher up.
@@ -679,6 +683,7 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 				conditionallyEnabledComponents{},
 				conditionallyEnabledComponents{},
 				config.UseUpstreamProxy(),
 				config.UseUpstreamProxy(),
 				limitTunnelProtocols,
 				limitTunnelProtocols,
+				limitTunnelDialPortNumbers,
 				false)) == 0 {
 				false)) == 0 {
 				return false, nil, errors.Tracef(
 				return false, nil, errors.Tracef(
 					"TargetServerEntry does not support LimitTunnelProtocols: %v", limitTunnelProtocols)
 					"TargetServerEntry does not support LimitTunnelProtocols: %v", limitTunnelProtocols)

+ 29 - 40
psiphon/dialParameters.go

@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"strconv"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
@@ -716,40 +717,27 @@ func MakeDialParameters(
 	// Set dial address fields. This portion of configuration is
 	// Set dial address fields. This portion of configuration is
 	// deterministic, given the parameters established or replayed so far.
 	// deterministic, given the parameters established or replayed so far.
 
 
-	switch dialParams.TunnelProtocol {
+	dialPortNumber, err := serverEntry.GetDialPortNumber(dialParams.TunnelProtocol)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
 
-	case protocol.TUNNEL_PROTOCOL_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)
+	dialParams.DialPortNumber = strconv.Itoa(dialPortNumber)
 
 
-	case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
+	switch dialParams.TunnelProtocol {
 
 
-	case protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedTapDancePort)
+	case protocol.TUNNEL_PROTOCOL_SSH,
+		protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
 
 
-	case protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedConjurePort)
+		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 
 
-	case protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedQUICPort)
+	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK,
+		protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
 
 
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
-		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
-		if serverEntry.MeekFrontingDisableSNI {
-			dialParams.MeekSNIServerName = ""
-			// When SNI is omitted, the transformed host name is not used.
-			dialParams.MeekTransformedHostName = false
-		} else if !dialParams.MeekTransformedHostName {
-			dialParams.MeekSNIServerName = dialParams.MeekFrontingDialAddress
-		}
-
-	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
-		// Note: port comes from marionnete "format"
-		dialParams.DirectDialAddress = serverEntry.IpAddress
-
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		if serverEntry.MeekFrontingDisableSNI {
 		if serverEntry.MeekFrontingDisableSNI {
 			dialParams.MeekSNIServerName = ""
 			dialParams.MeekSNIServerName = ""
@@ -760,15 +748,17 @@ func MakeDialParameters(
 		}
 		}
 
 
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:80", dialParams.MeekFrontingDialAddress)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		// For FRONTED HTTP, the Host header cannot be transformed.
 		// For FRONTED HTTP, the Host header cannot be transformed.
 		dialParams.MeekTransformedHostName = false
 		dialParams.MeekTransformedHostName = false
 
 
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 		if !dialParams.MeekTransformedHostName {
-			if serverEntry.MeekServerPort == 80 {
+			if dialPortNumber == 80 {
 				dialParams.MeekHostHeader = serverEntry.IpAddress
 				dialParams.MeekHostHeader = serverEntry.IpAddress
 			} else {
 			} else {
 				dialParams.MeekHostHeader = dialParams.MeekDialAddress
 				dialParams.MeekHostHeader = dialParams.MeekDialAddress
@@ -778,17 +768,22 @@ func MakeDialParameters(
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
 		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
 
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 		if !dialParams.MeekTransformedHostName {
 			// Note: IP address in SNI field will be omitted.
 			// Note: IP address in SNI field will be omitted.
 			dialParams.MeekSNIServerName = serverEntry.IpAddress
 			dialParams.MeekSNIServerName = serverEntry.IpAddress
 		}
 		}
-		if serverEntry.MeekServerPort == 443 {
+		if dialPortNumber == 443 {
 			dialParams.MeekHostHeader = serverEntry.IpAddress
 			dialParams.MeekHostHeader = serverEntry.IpAddress
 		} else {
 		} else {
 			dialParams.MeekHostHeader = dialParams.MeekDialAddress
 			dialParams.MeekHostHeader = dialParams.MeekDialAddress
 		}
 		}
 
 
+	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+
+		// Note: port comes from marionette "format"
+		dialParams.DirectDialAddress = serverEntry.IpAddress
+
 	default:
 	default:
 		return nil, errors.Tracef(
 		return nil, errors.Tracef(
 			"unknown tunnel protocol: %s", dialParams.TunnelProtocol)
 			"unknown tunnel protocol: %s", dialParams.TunnelProtocol)
@@ -797,7 +792,7 @@ func MakeDialParameters(
 
 
 	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) {
 	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) {
 
 
-		host, port, _ := net.SplitHostPort(dialParams.MeekDialAddress)
+		host, _, _ := net.SplitHostPort(dialParams.MeekDialAddress)
 
 
 		if p.Bool(parameters.MeekDialDomainsOnly) {
 		if p.Bool(parameters.MeekDialDomainsOnly) {
 			if net.ParseIP(host) != nil {
 			if net.ParseIP(host) != nil {
@@ -806,17 +801,11 @@ func MakeDialParameters(
 			}
 			}
 		}
 		}
 
 
-		dialParams.DialPortNumber = port
-
 		// The underlying TLS will automatically disable SNI for IP address server name
 		// The underlying TLS will automatically disable SNI for IP address server name
 		// values; we have this explicit check here so we record the correct value for stats.
 		// values; we have this explicit check here so we record the correct value for stats.
 		if net.ParseIP(dialParams.MeekSNIServerName) != nil {
 		if net.ParseIP(dialParams.MeekSNIServerName) != nil {
 			dialParams.MeekSNIServerName = ""
 			dialParams.MeekSNIServerName = ""
 		}
 		}
-
-	} else {
-
-		_, dialParams.DialPortNumber, _ = net.SplitHostPort(dialParams.DirectDialAddress)
 	}
 	}
 
 
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.

+ 116 - 11
psiphon/dialParameters_test.go

@@ -87,7 +87,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
-	err = clientConfig.SetParameters("tag1", true, applyParameters)
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
@@ -346,7 +346,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	// Test: no replay after change tactics
 	// Test: no replay after change tactics
 
 
 	applyParameters[parameters.ReplayDialParametersTTL] = "1s"
 	applyParameters[parameters.ReplayDialParametersTTL] = "1s"
-	err = clientConfig.SetParameters("tag2", true, applyParameters)
+	err = clientConfig.SetParameters("tag2", false, applyParameters)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
@@ -400,7 +400,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.ReplayObfuscatedQUIC] = false
 	applyParameters[parameters.ReplayObfuscatedQUIC] = false
 	applyParameters[parameters.ReplayLivenessTest] = false
 	applyParameters[parameters.ReplayLivenessTest] = false
 	applyParameters[parameters.ReplayAPIRequestPadding] = false
 	applyParameters[parameters.ReplayAPIRequestPadding] = false
-	err = clientConfig.SetParameters("tag3", true, applyParameters)
+	err = clientConfig.SetParameters("tag3", false, applyParameters)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
@@ -442,7 +442,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 
 	applyParameters[parameters.RestrictFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.RestrictFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 1.0
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 1.0
-	err = clientConfig.SetParameters("tag4", true, applyParameters)
+	err = clientConfig.SetParameters("tag4", false, applyParameters)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
@@ -462,7 +462,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	}
 
 
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 0.0
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 0.0
-	err = clientConfig.SetParameters("tag5", true, applyParameters)
+	err = clientConfig.SetParameters("tag5", false, applyParameters)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
@@ -558,6 +558,110 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	}
 }
 }
 
 
+func TestLimitTunnelDialPortNumbers(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-limit-tunnel-dial-port-numbers-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %s", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	SetNoticeWriter(ioutil.Discard)
+
+	clientConfig := &Config{
+		PropagationChannelId: "0",
+		SponsorId:            "0",
+		DataRootDirectory:    testDataDirName,
+		NetworkIDGetter:      new(testNetworkGetter),
+	}
+
+	err = clientConfig.Commit(false)
+	if err != nil {
+		t.Fatalf("error committing configuration file: %s", err)
+	}
+
+	jsonLimitDialPortNumbers := `
+    {
+        "SSH" : [[10,11]],
+        "OSSH" : [[20,21]],
+        "QUIC-OSSH" : [[30,31]],
+        "TAPDANCE-OSSH" : [[40,41]],
+        "CONJURE-OSSH" : [[50,51]],
+        "All" : [[60,61],80,443]
+    }
+    `
+
+	var limitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+	err = json.Unmarshal([]byte(jsonLimitDialPortNumbers), &limitTunnelDialPortNumbers)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	applyParameters := make(map[string]interface{})
+	applyParameters[parameters.LimitTunnelDialPortNumbers] = limitTunnelDialPortNumbers
+	applyParameters[parameters.LimitTunnelDialPortNumbersProbability] = 1.0
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	constraints := &protocolSelectionConstraints{
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			clientConfig.GetParameters().Get().TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+	}
+
+	selectProtocol := func(serverEntry *protocol.ServerEntry) (string, bool) {
+		return constraints.selectProtocol(0, false, serverEntry)
+	}
+
+	for _, tunnelProtocol := range protocol.SupportedTunnelProtocols {
+
+		if common.Contains(protocol.DefaultDisabledTunnelProtocols, tunnelProtocol) {
+			continue
+		}
+
+		serverEntries := makeMockServerEntries(tunnelProtocol, "", 100)
+
+		selected := false
+		skipped := false
+
+		for _, serverEntry := range serverEntries {
+
+			selectedProtocol, ok := selectProtocol(serverEntry)
+
+			if ok {
+
+				if selectedProtocol != tunnelProtocol {
+					t.Fatalf("unexpected selected protocol: %s", selectedProtocol)
+				}
+
+				port, err := serverEntry.GetDialPortNumber(selectedProtocol)
+				if err != nil {
+					t.Fatalf("GetDialPortNumber failed: %s", err)
+				}
+
+				if port%10 != 0 && port%10 != 1 && !protocol.TunnelProtocolUsesFrontedMeek(selectedProtocol) {
+					t.Fatalf("unexpected dial port number: %d", port)
+				}
+
+				selected = true
+
+			} else {
+
+				skipped = true
+			}
+		}
+
+		if !selected {
+			t.Fatalf("expected at least one selected server entry: %s", tunnelProtocol)
+		}
+
+		if !skipped && !protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+			t.Fatalf("expected at least one skipped server entry: %s", tunnelProtocol)
+		}
+	}
+}
+
 func makeMockServerEntries(
 func makeMockServerEntries(
 	tunnelProtocol string,
 	tunnelProtocol string,
 	frontingProviderID string,
 	frontingProviderID string,
@@ -568,17 +672,18 @@ func makeMockServerEntries(
 	for i := 0; i < count; i++ {
 	for i := 0; i < count; i++ {
 		serverEntries[i] = &protocol.ServerEntry{
 		serverEntries[i] = &protocol.ServerEntry{
 			IpAddress:                  fmt.Sprintf("192.168.0.%d", i),
 			IpAddress:                  fmt.Sprintf("192.168.0.%d", i),
-			SshPort:                    1,
-			SshObfuscatedPort:          2,
-			SshObfuscatedQUICPort:      3,
-			SshObfuscatedTapDancePort:  4,
-			SshObfuscatedConjurePort:   5,
-			MeekServerPort:             6,
+			SshPort:                    prng.Range(10, 19),
+			SshObfuscatedPort:          prng.Range(20, 29),
+			SshObfuscatedQUICPort:      prng.Range(30, 39),
+			SshObfuscatedTapDancePort:  prng.Range(40, 49),
+			SshObfuscatedConjurePort:   prng.Range(50, 59),
+			MeekServerPort:             prng.Range(60, 69),
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
 			FrontingProviderID:         frontingProviderID,
 			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
+			Capabilities:               []string{protocol.GetCapability(tunnelProtocol)},
 		}
 		}
 	}
 	}
 
 

BIN
psiphon/feedback_test.config.enc


+ 6 - 7
psiphon/feedback_test.go

@@ -23,6 +23,7 @@ import (
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
 	"io/ioutil"
 	"io/ioutil"
+	"os/exec"
 	"testing"
 	"testing"
 )
 )
 
 
@@ -41,7 +42,7 @@ type Diagnostics struct {
 }
 }
 
 
 func TestFeedbackUpload(t *testing.T) {
 func TestFeedbackUpload(t *testing.T) {
-	configFileContents, err := ioutil.ReadFile("feedback_test.config")
+	configFileContents, err := ioutil.ReadFile("controller_test.config")
 	if err != nil {
 	if err != nil {
 		// Skip, don't fail, if config file is not present
 		// Skip, don't fail, if config file is not present
 		t.Skipf("error loading configuration file: %s", err)
 		t.Skipf("error loading configuration file: %s", err)
@@ -65,13 +66,11 @@ func TestFeedbackUpload(t *testing.T) {
 		t.Fatalf("error committing configuration file: %s", err)
 		t.Fatalf("error committing configuration file: %s", err)
 	}
 	}
 
 
-	// git_rev is a file which contains the shortened hash of the latest commit
-	// pointed to by HEAD, i.e. git rev-parse --short HEAD.
-
-	shortRevHash, err := ioutil.ReadFile("git_rev")
+	shortRevHash, err := exec.Command("git", "rev-parse", "--short", "HEAD").Output()
 	if err != nil {
 	if err != nil {
-		// Skip, don't fail, if git rev file is not present
-		t.Skipf("error loading git revision file: %s", err)
+		// Log, don't fail, if git rev is not available
+		t.Logf("error loading git revision file: %s", err)
+		shortRevHash = []byte("unknown")
 	}
 	}
 
 
 	// Construct feedback data which can be verified later
 	// Construct feedback data which can be verified later

+ 12 - 0
psiphon/meekConn.go

@@ -645,6 +645,14 @@ func DialMeek(
 		meek.meekObfuscatedKey = meekConfig.MeekObfuscatedKey
 		meek.meekObfuscatedKey = meekConfig.MeekObfuscatedKey
 		meek.meekObfuscatorPaddingSeed = meekConfig.MeekObfuscatorPaddingSeed
 		meek.meekObfuscatorPaddingSeed = meekConfig.MeekObfuscatorPaddingSeed
 		meek.clientTunnelProtocol = meekConfig.ClientTunnelProtocol
 		meek.clientTunnelProtocol = meekConfig.ClientTunnelProtocol
+
+	} else if meek.mode == MeekModePlaintextRoundTrip {
+
+		// MeekModeRelay and MeekModeObfuscatedRoundTrip set the Host header
+		// implicitly via meek.url; MeekModePlaintextRoundTrip does not use
+		// meek.url; it uses the RoundTrip input request.URL instead. So the
+		// Host header is set to meekConfig.HostHeader explicitly here.
+		meek.additionalHeaders.Add("Host", meekConfig.HostHeader)
 	}
 	}
 
 
 	return meek, nil
 	return meek, nil
@@ -842,6 +850,10 @@ func (meek *MeekConn) RoundTrip(request *http.Request) (*http.Response, error) {
 
 
 	requestCtx := request.Context()
 	requestCtx := request.Context()
 
 
+	// Clone the request to apply addtional headers without modifying the input.
+	request = request.Clone(requestCtx)
+	meek.addAdditionalHeaders(request)
+
 	// The setDialerRequestContext/CloseIdleConnections concurrency note in
 	// The setDialerRequestContext/CloseIdleConnections concurrency note in
 	// ObfuscatedRoundTrip applies to RoundTrip as well.
 	// ObfuscatedRoundTrip applies to RoundTrip as well.
 
 

+ 3 - 0
psiphon/net.go

@@ -309,6 +309,9 @@ func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration
 
 
 	// Process the response
 	// Process the response
 	response, err := dnsConn.ReadMsg()
 	response, err := dnsConn.ReadMsg()
+	if err == nil && response.MsgHdr.Id != query.MsgHdr.Id {
+		err = dns.ErrId
+	}
 	if err != nil {
 	if err != nil {
 		return nil, nil, errors.Trace(err)
 		return nil, nil, errors.Trace(err)
 	}
 	}

+ 16 - 8
psiphon/notice.go

@@ -425,9 +425,10 @@ func NoticeCandidateServers(
 	singletonNoticeLogger.outputNotice(
 	singletonNoticeLogger.outputNotice(
 		"CandidateServers", noticeIsDiagnostic,
 		"CandidateServers", noticeIsDiagnostic,
 		"region", region,
 		"region", region,
-		"initialLimitTunnelProtocols", constraints.initialLimitProtocols,
-		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitProtocolsCandidateCount,
-		"limitTunnelProtocols", constraints.limitProtocols,
+		"initialLimitTunnelProtocols", constraints.initialLimitTunnelProtocols,
+		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitTunnelProtocolsCandidateCount,
+		"limitTunnelProtocols", constraints.limitTunnelProtocols,
+		"limitTunnelDialPortNumbers", constraints.limitTunnelDialPortNumbers,
 		"replayCandidateCount", constraints.replayCandidateCount,
 		"replayCandidateCount", constraints.replayCandidateCount,
 		"initialCount", initialCount,
 		"initialCount", initialCount,
 		"count", count,
 		"count", count,
@@ -1034,14 +1035,21 @@ func GetNotice(notice []byte) (
 	var object noticeObject
 	var object noticeObject
 	err = json.Unmarshal(notice, &object)
 	err = json.Unmarshal(notice, &object)
 	if err != nil {
 	if err != nil {
-		return "", nil, err
+		return "", nil, errors.Trace(err)
 	}
 	}
-	var objectPayload interface{}
-	err = json.Unmarshal(object.Data, &objectPayload)
+
+	var data interface{}
+	err = json.Unmarshal(object.Data, &data)
 	if err != nil {
 	if err != nil {
-		return "", nil, err
+		return "", nil, errors.Trace(err)
+	}
+
+	dataValue, ok := data.(map[string]interface{})
+	if !ok {
+		return "", nil, errors.TraceNew("invalid data value")
 	}
 	}
-	return object.NoticeType, objectPayload.(map[string]interface{}), nil
+
+	return object.NoticeType, dataValue, nil
 }
 }
 
 
 // NoticeReceiver consumes a notice input stream and invokes a callback function
 // NoticeReceiver consumes a notice input stream and invokes a callback function

+ 6 - 30
psiphon/server/listener.go

@@ -25,14 +25,15 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
 )
 
 
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // of certain tactics parameters to accepted connections. Tactics filtering is
 // of certain tactics parameters to accepted connections. Tactics filtering is
-// limited to GeoIP attributes as the client has not yet sent API paramaters.
+// limited to GeoIP attributes as the client has not yet sent API parameters.
+// GeoIP uses the immediate peer IP, and so TacticsListener is suitable only
+// for tactics that do not require the original client GeoIP when fronted.
 type TacticsListener struct {
 type TacticsListener struct {
 	net.Listener
 	net.Listener
 	support        *SupportServices
 	support        *SupportServices
@@ -77,11 +78,14 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	// Limitation: RemoteAddr is the immediate peer IP, which is not the original
+	// client IP in the case of fronting.
 	geoIPData := listener.geoIPLookup(
 	geoIPData := listener.geoIPLookup(
 		common.IPAddressFromAddr(conn.RemoteAddr()))
 		common.IPAddressFromAddr(conn.RemoteAddr()))
 
 
 	p, err := listener.support.ServerTacticsParametersCache.Get(geoIPData)
 	p, err := listener.support.ServerTacticsParametersCache.Get(geoIPData)
 	if err != nil {
 	if err != nil {
+		conn.Close()
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
@@ -90,34 +94,6 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return conn, nil
 		return conn, nil
 	}
 	}
 
 
-	// Disconnect immediately if the clients tactics restricts usage of the
-	// fronting provider ID. The probability may be used to influence usage of a
-	// given fronting provider; but when only that provider works for a given
-	// client, and the probability is less than 1.0, the client can retry until
-	// it gets a successful coin flip.
-	//
-	// Clients will also skip candidates with restricted fronting provider IDs.
-	// The client-side probability, RestrictFrontingProviderIDsClientProbability,
-	// is applied independently of the server-side coin flip here.
-	//
-	//
-	// At this stage, GeoIP tactics filters are active, but handshake API
-	// parameters are not.
-	//
-	// See the comment in server.LoadConfig regarding fronting provider ID
-	// limitations.
-
-	if protocol.TunnelProtocolUsesFrontedMeek(listener.tunnelProtocol) &&
-		common.Contains(
-			p.Strings(parameters.RestrictFrontingProviderIDs),
-			listener.support.Config.GetFrontingProviderID()) {
-		if p.WeightedCoinFlip(
-			parameters.RestrictFrontingProviderIDsServerProbability) {
-			conn.Close()
-			return nil, nil
-		}
-	}
-
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	//
 	//
 	// In the OSSH case, replay is always activated and it is seeded using the
 	// In the OSSH case, replay is always activated and it is seeded using the

+ 2 - 36
psiphon/server/listener_test.go

@@ -28,7 +28,6 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 )
 )
@@ -37,8 +36,6 @@ func TestListener(t *testing.T) {
 
 
 	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
 	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
 
 
-	frontingProviderID := prng.HexString(8)
-
 	tacticsConfigJSONFormat := `
 	tacticsConfigJSONFormat := `
     {
     {
       "RequestPublicKey" : "%s",
       "RequestPublicKey" : "%s",
@@ -65,19 +62,6 @@ func TestListener(t *testing.T) {
               "FragmentorDownstreamMaxWriteBytes" : 1
               "FragmentorDownstreamMaxWriteBytes" : 1
             }
             }
           }
           }
-        },
-        {
-          "Filter" : {
-            "Regions": ["R3"],
-            "ISPs": ["I3"],
-            "Cities": ["C3"]
-          },
-          "Tactics" : {
-            "Parameters" : {
-              "RestrictFrontingProviderIDs" : ["%s"],
-              "RestrictFrontingProviderIDsServerProbability" : 1.0
-            }
-          }
         }
         }
       ]
       ]
     }
     }
@@ -92,7 +76,7 @@ func TestListener(t *testing.T) {
 	tacticsConfigJSON := fmt.Sprintf(
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
-		tunnelProtocol, frontingProviderID)
+		tunnelProtocol)
 
 
 	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
 	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
 
 
@@ -122,12 +106,6 @@ func TestListener(t *testing.T) {
 	listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData {
 	listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData {
 		return GeoIPData{Country: "R1", ISP: "I1", City: "C2"}
 		return GeoIPData{Country: "R1", ISP: "I1", City: "C2"}
 	}
 	}
-	listenerRestrictedFrontingProviderIDGeoIP := func(string) GeoIPData {
-		return GeoIPData{Country: "R3", ISP: "I3", City: "C3"}
-	}
-	listenerUnrestrictedFrontingProviderIDWrongRegion := func(string) GeoIPData {
-		return GeoIPData{Country: "R2", ISP: "I3", City: "C3"}
-	}
 
 
 	listenerTestCases := []struct {
 	listenerTestCases := []struct {
 		description      string
 		description      string
@@ -159,18 +137,6 @@ func TestListener(t *testing.T) {
 			false,
 			false,
 			true,
 			true,
 		},
 		},
-		{
-			"restricted",
-			listenerRestrictedFrontingProviderIDGeoIP,
-			false,
-			false,
-		},
-		{
-			"unrestricted-region",
-			listenerUnrestrictedFrontingProviderIDWrongRegion,
-			false,
-			true,
-		},
 	}
 	}
 
 
 	for _, testCase := range listenerTestCases {
 	for _, testCase := range listenerTestCases {
@@ -182,7 +148,7 @@ func TestListener(t *testing.T) {
 			}
 			}
 
 
 			support := &SupportServices{
 			support := &SupportServices{
-				Config:        &Config{frontingProviderID: frontingProviderID},
+				Config:        &Config{},
 				TacticsServer: tacticsServer,
 				TacticsServer: tacticsServer,
 			}
 			}
 			support.ReplayCache = NewReplayCache(support)
 			support.ReplayCache = NewReplayCache(support)

+ 66 - 24
psiphon/server/meek.go

@@ -43,6 +43,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
@@ -344,8 +345,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		session,
 		session,
 		underlyingConn,
 		underlyingConn,
 		endPoint,
 		endPoint,
-		clientIP,
+		endPointGeoIPData,
 		err := server.getSessionOrEndpoint(request, meekCookie)
 		err := server.getSessionOrEndpoint(request, meekCookie)
+
 	if err != nil {
 	if err != nil {
 		// Debug since session cookie errors commonly occur during
 		// Debug since session cookie errors commonly occur during
 		// normal operation.
 		// normal operation.
@@ -359,9 +361,8 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// request handler.
 		// request handler.
 
 
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 		handled := server.support.TacticsServer.HandleEndPoint(
 		handled := server.support.TacticsServer.HandleEndPoint(
-			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
+			endPoint, common.GeoIPData(*endPointGeoIPData), responseWriter, request)
 		if !handled {
 		if !handled {
 			log.WithTraceFields(LogFields{"endPoint": endPoint}).Info("unhandled endpoint")
 			log.WithTraceFields(LogFields{"endPoint": endPoint}).Info("unhandled endpoint")
 			common.TerminateHTTPConnection(responseWriter, request)
 			common.TerminateHTTPConnection(responseWriter, request)
@@ -587,7 +588,7 @@ func checkRangeHeader(request *http.Request) (int, bool) {
 // mode; or the endpoint is returned when the meek cookie indicates endpoint
 // mode; or the endpoint is returned when the meek cookie indicates endpoint
 // mode.
 // mode.
 func (server *MeekServer) getSessionOrEndpoint(
 func (server *MeekServer) getSessionOrEndpoint(
-	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, string, error) {
+	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, *GeoIPData, error) {
 
 
 	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
 	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
 
 
@@ -601,7 +602,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 		// TODO: can multiple http client connections using same session cookie
 		// TODO: can multiple http client connections using same session cookie
 		// cause race conditions on session struct?
 		// cause race conditions on session struct?
 		session.touch()
 		session.touch()
-		return existingSessionID, session, underlyingConn, "", "", nil
+		return existingSessionID, session, underlyingConn, "", nil, nil
 	}
 	}
 
 
 	// Determine the client remote address, which is used for geolocation
 	// Determine the client remote address, which is used for geolocation
@@ -610,6 +611,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// headers such as X-Forwarded-For.
 	// headers such as X-Forwarded-For.
 
 
 	clientIP := strings.Split(request.RemoteAddr, ":")[0]
 	clientIP := strings.Split(request.RemoteAddr, ":")[0]
+	usedProxyForwardedForHeader := false
+	var geoIPData GeoIPData
 
 
 	if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 {
 	if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 {
 		for _, header := range server.support.Config.MeekProxyForwardedForHeaders {
 		for _, header := range server.support.Config.MeekProxyForwardedForHeaders {
@@ -619,23 +622,29 @@ func (server *MeekServer) getSessionOrEndpoint(
 				// list of IPs (each proxy in a chain). The first IP should be
 				// list of IPs (each proxy in a chain). The first IP should be
 				// the client IP.
 				// the client IP.
 				proxyClientIP := strings.Split(value, ",")[0]
 				proxyClientIP := strings.Split(value, ",")[0]
-				if net.ParseIP(proxyClientIP) != nil &&
-					server.support.GeoIPService.Lookup(
-						proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
-
-					clientIP = proxyClientIP
-					break
+				if net.ParseIP(proxyClientIP) != nil {
+					proxyClientGeoIPData := server.support.GeoIPService.Lookup(proxyClientIP)
+					if proxyClientGeoIPData.Country != GEOIP_UNKNOWN_VALUE {
+						usedProxyForwardedForHeader = true
+						clientIP = proxyClientIP
+						geoIPData = proxyClientGeoIPData
+						break
+					}
 				}
 				}
 			}
 			}
 		}
 		}
 	}
 	}
 
 
+	if !usedProxyForwardedForHeader {
+		geoIPData = server.support.GeoIPService.Lookup(clientIP)
+	}
+
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// cookie, extract the payload, and create a new session.
 	// cookie, extract the payload, and create a new session.
 
 
 	payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
 	payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
 	if err != nil {
 	if err != nil {
-		return "", nil, nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", nil, errors.Trace(err)
 	}
 	}
 
 
 	// Note: this meek server ignores legacy values PsiphonClientSessionId
 	// Note: this meek server ignores legacy values PsiphonClientSessionId
@@ -644,7 +653,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	if err != nil {
 	if err != nil {
-		return "", nil, nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", nil, errors.Trace(err)
 	}
 	}
 
 
 	tunnelProtocol := server.listenerTunnelProtocol
 	tunnelProtocol := server.listenerTunnelProtocol
@@ -656,7 +665,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 			server.listenerTunnelProtocol,
 			server.listenerTunnelProtocol,
 			server.support.Config.GetRunningProtocols()) {
 			server.support.Config.GetRunningProtocols()) {
 
 
-			return "", nil, nil, "", "", errors.Tracef(
+			return "", nil, nil, "", nil, errors.Tracef(
 				"invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol)
 				"invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol)
 		}
 		}
 
 
@@ -669,8 +678,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// rate limit is primarily intended to limit memory resource consumption and
 	// rate limit is primarily intended to limit memory resource consumption and
 	// not the overhead incurred by cookie validation.
 	// not the overhead incurred by cookie validation.
 
 
-	if server.rateLimit(clientIP, tunnelProtocol) {
-		return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
+	if server.rateLimit(clientIP, geoIPData, tunnelProtocol) {
+		return "", nil, nil, "", nil, errors.TraceNew("rate limit exceeded")
 	}
 	}
 
 
 	// Handle endpoints before enforcing CheckEstablishTunnels.
 	// Handle endpoints before enforcing CheckEstablishTunnels.
@@ -678,7 +687,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// handled by servers which would otherwise reject new tunnels.
 	// handled by servers which would otherwise reject new tunnels.
 
 
 	if clientSessionData.EndPoint != "" {
 	if clientSessionData.EndPoint != "" {
-		return "", nil, nil, clientSessionData.EndPoint, clientIP, nil
+		return "", nil, nil, clientSessionData.EndPoint, &geoIPData, nil
 	}
 	}
 
 
 	// Don't create new sessions when not establishing. A subsequent SSH handshake
 	// Don't create new sessions when not establishing. A subsequent SSH handshake
@@ -686,7 +695,42 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	if server.support.TunnelServer != nil &&
 	if server.support.TunnelServer != nil &&
 		!server.support.TunnelServer.CheckEstablishTunnels() {
 		!server.support.TunnelServer.CheckEstablishTunnels() {
-		return "", nil, nil, "", "", errors.TraceNew("not establishing tunnels")
+		return "", nil, nil, "", nil, errors.TraceNew("not establishing tunnels")
+	}
+
+	// Disconnect immediately if the tactics for the client restricts usage of
+	// the fronting provider ID. The probability may be used to influence
+	// usage of a given fronting provider; but when only that provider works
+	// for a given client, and the probability is less than 1.0, the client
+	// can retry until it gets a successful coin flip.
+	//
+	// Clients will also skip candidates with restricted fronting provider IDs.
+	// The client-side probability, RestrictFrontingProviderIDsClientProbability,
+	// is applied independently of the server-side coin flip here.
+	//
+	// At this stage, GeoIP tactics filters are active, but handshake API
+	// parameters are not.
+	//
+	// See the comment in server.LoadConfig regarding fronting provider ID
+	// limitations.
+
+	if protocol.TunnelProtocolUsesFrontedMeek(server.listenerTunnelProtocol) &&
+		server.support.ServerTacticsParametersCache != nil {
+
+		p, err := server.support.ServerTacticsParametersCache.Get(geoIPData)
+		if err != nil {
+			return "", nil, nil, "", nil, errors.Trace(err)
+		}
+
+		if !p.IsNil() &&
+			common.Contains(
+				p.Strings(parameters.RestrictFrontingProviderIDs),
+				server.support.Config.GetFrontingProviderID()) {
+			if p.WeightedCoinFlip(
+				parameters.RestrictFrontingProviderIDsServerProbability) {
+				return "", nil, nil, "", nil, errors.TraceNew("restricted fronting provider")
+			}
+		}
 	}
 	}
 
 
 	// Create a new session
 	// Create a new session
@@ -736,7 +780,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 		sessionID, err = makeMeekSessionID()
 		sessionID, err = makeMeekSessionID()
 		if err != nil {
 		if err != nil {
-			return "", nil, nil, "", "", errors.Trace(err)
+			return "", nil, nil, "", nil, errors.Trace(err)
 		}
 		}
 	}
 	}
 
 
@@ -748,10 +792,11 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// will close when session.delete calls Close() on the meekConn.
 	// will close when session.delete calls Close() on the meekConn.
 	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 
 
-	return sessionID, session, underlyingConn, "", "", nil
+	return sessionID, session, underlyingConn, "", nil, nil
 }
 }
 
 
-func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool {
+func (server *MeekServer) rateLimit(
+	clientIP string, geoIPData GeoIPData, tunnelProtocol string) bool {
 
 
 	historySize,
 	historySize,
 		thresholdSeconds,
 		thresholdSeconds,
@@ -774,9 +819,6 @@ func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool
 
 
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 
 
-		// TODO: avoid redundant GeoIP lookups?
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
-
 		if len(regions) > 0 {
 		if len(regions) > 0 {
 			if !common.Contains(regions, geoIPData.Country) {
 			if !common.Contains(regions, geoIPData.Country) {
 				return false
 				return false

+ 75 - 7
psiphon/server/meek_test.go

@@ -25,8 +25,10 @@ import (
 	crypto_rand "crypto/rand"
 	crypto_rand "crypto/rand"
 	"encoding/base64"
 	"encoding/base64"
 	"fmt"
 	"fmt"
+	"io/ioutil"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
+	"path/filepath"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
@@ -38,6 +40,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/crypto/nacl/box"
 )
 )
 
 
@@ -245,6 +248,7 @@ func TestMeekResiliency(t *testing.T) {
 		},
 		},
 		TrafficRulesSet: &TrafficRulesSet{},
 		TrafficRulesSet: &TrafficRulesSet{},
 	}
 	}
+	mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
 
 
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
@@ -401,19 +405,73 @@ func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) (
 }
 }
 
 
 func TestMeekRateLimiter(t *testing.T) {
 func TestMeekRateLimiter(t *testing.T) {
-	runTestMeekRateLimiter(t, true)
-	runTestMeekRateLimiter(t, false)
+	runTestMeekAccessControl(t, true, false)
+	runTestMeekAccessControl(t, false, false)
 }
 }
 
 
-func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
+func TestMeekRestrictFrontingProviders(t *testing.T) {
+	runTestMeekAccessControl(t, false, true)
+	runTestMeekAccessControl(t, false, false)
+}
+
+func runTestMeekAccessControl(t *testing.T, rateLimit, restrictProvider bool) {
 
 
 	attempts := 10
 	attempts := 10
 
 
 	allowedConnections := 5
 	allowedConnections := 5
+
 	if !rateLimit {
 	if !rateLimit {
 		allowedConnections = 10
 		allowedConnections = 10
 	}
 	}
 
 
+	if restrictProvider {
+		allowedConnections = 0
+	}
+
+	// Configure tactics
+
+	frontingProviderID := prng.HexString(8)
+
+	tacticsConfigJSONFormat := `
+    {
+      "RequestPublicKey" : "%s",
+      "RequestPrivateKey" : "%s",
+      "RequestObfuscatedKey" : "%s",
+      "DefaultTactics" : {
+        "TTL" : "60s",
+        "Probability" : 1.0,
+        "Parameters" : {
+          "RestrictFrontingProviderIDs" : ["%s"],
+          "RestrictFrontingProviderIDsServerProbability" : 1.0
+        }
+      }
+    }
+    `
+
+	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
+		tactics.GenerateKeys()
+	if err != nil {
+		t.Fatalf("error generating tactics keys: %s", err)
+	}
+
+	restrictFrontingProviderID := ""
+
+	if restrictProvider {
+		restrictFrontingProviderID = frontingProviderID
+	}
+
+	tacticsConfigJSON := fmt.Sprintf(
+		tacticsConfigJSONFormat,
+		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+		restrictFrontingProviderID)
+
+	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
+
+	err = ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600)
+	if err != nil {
+		t.Fatalf("error paving tactics config file: %s", err)
+	}
+
 	// Run meek server
 	// Run meek server
 
 
 	rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
 	rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
@@ -424,11 +482,11 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 	meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
 	meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
 	meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 	meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 
 
-	tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK
+	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
 
 
 	meekRateLimiterTunnelProtocols := []string{tunnelProtocol}
 	meekRateLimiterTunnelProtocols := []string{tunnelProtocol}
 	if !rateLimit {
 	if !rateLimit {
-		meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS}
+		meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_FRONTED_MEEK}
 	}
 	}
 
 
 	mockSupport := &SupportServices{
 	mockSupport := &SupportServices{
@@ -436,6 +494,7 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 			MeekObfuscatedKey:              meekObfuscatedKey,
 			MeekObfuscatedKey:              meekObfuscatedKey,
 			MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
 			MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
 			TunnelProtocolPorts:            map[string]int{tunnelProtocol: 0},
 			TunnelProtocolPorts:            map[string]int{tunnelProtocol: 0},
+			frontingProviderID:             frontingProviderID,
 		},
 		},
 		TrafficRulesSet: &TrafficRulesSet{
 		TrafficRulesSet: &TrafficRulesSet{
 			MeekRateLimiterHistorySize:                   allowedConnections,
 			MeekRateLimiterHistorySize:                   allowedConnections,
@@ -445,6 +504,15 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 			MeekRateLimiterReapHistoryFrequencySeconds:   1,
 			MeekRateLimiterReapHistoryFrequencySeconds:   1,
 		},
 		},
 	}
 	}
+	mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
+
+	tacticsServer, err := tactics.NewServer(nil, nil, nil, tacticsConfigFilename)
+	if err != nil {
+		t.Fatalf("tactics.NewServer failed: %s", err)
+	}
+
+	mockSupport.TacticsServer = tacticsServer
+	mockSupport.ServerTacticsParametersCache = NewServerTacticsParametersCache(mockSupport)
 
 
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
@@ -567,8 +635,8 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 		totalFailures != attempts-totalConnections {
 		totalFailures != attempts-totalConnections {
 
 
 		t.Fatalf(
 		t.Fatalf(
-			"Unexpected results: %d connections, %d failures",
-			totalConnections, totalFailures)
+			"Unexpected results: %d connections, %d failures, %d allowed",
+			totalConnections, totalFailures, allowedConnections)
 	}
 	}
 
 
 	// Graceful shutdown
 	// Graceful shutdown

+ 1 - 1
psiphon/server/psinet/psinet.go

@@ -396,7 +396,7 @@ func calculateBucketCount(length int) int {
 func bucketizeServerList(servers []*DiscoveryServer, bucketCount int) [][]*DiscoveryServer {
 func bucketizeServerList(servers []*DiscoveryServer, bucketCount int) [][]*DiscoveryServer {
 
 
 	// This code creates the same partitions as legacy servers:
 	// This code creates the same partitions as legacy servers:
-	// https://bitbucket.org/psiphon/psiphon-circumvention-system/src/03bc1a7e51e7c85a816e370bb3a6c755fd9c6fee/Automation/psi_ops_discovery.py
+	// https://github.com/Psiphon-Inc/psiphon-automation/blob/685f91a85bcdb33a75a200d936eadcb0686eadd7/Automation/psi_ops_discovery.py
 	//
 	//
 	// Both use the same algorithm from:
 	// Both use the same algorithm from:
 	// http://stackoverflow.com/questions/2659900/python-slicing-a-list-into-n-nearly-equal-length-partitions
 	// http://stackoverflow.com/questions/2659900/python-slicing-a-list-into-n-nearly-equal-length-partitions

+ 76 - 6
psiphon/server/server_test.go

@@ -1338,6 +1338,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	expectServerBPFField := ServerBPFEnabled() && doServerTactics
 	expectServerBPFField := ServerBPFEnabled() && doServerTactics
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
 	expectBurstFields := runConfig.doBurstMonitor
 	expectBurstFields := runConfig.doBurstMonitor
+	expectTCPPortForwardDial := runConfig.doTunneledWebRequest
+	expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel
+	// Even with expectTrafficFailure, DNS port forwards will succeed
+	expectUDPDataTransfer := runConfig.doTunneledNTPRequest
 
 
 	select {
 	select {
 	case logFields := <-serverTunnelLog:
 	case logFields := <-serverTunnelLog:
@@ -1347,6 +1351,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectServerBPFField,
 			expectServerBPFField,
 			expectServerPacketManipulationField,
 			expectServerPacketManipulationField,
 			expectBurstFields,
 			expectBurstFields,
+			expectTCPPortForwardDial,
+			expectTCPDataTransfer,
+			expectUDPDataTransfer,
 			logFields)
 			logFields)
 		if err != nil {
 		if err != nil {
 			t.Fatalf("invalid server tunnel log fields: %s", err)
 			t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1404,6 +1411,9 @@ func checkExpectedServerTunnelLogFields(
 	expectServerBPFField bool,
 	expectServerBPFField bool,
 	expectServerPacketManipulationField bool,
 	expectServerPacketManipulationField bool,
 	expectBurstFields bool,
 	expectBurstFields bool,
+	expectTCPPortForwardDial bool,
+	expectTCPDataTransfer bool,
+	expectUDPDataTransfer bool,
 	fields map[string]interface{}) error {
 	fields map[string]interface{}) error {
 
 
 	// Limitations:
 	// Limitations:
@@ -1649,6 +1659,66 @@ func checkExpectedServerTunnelLogFields(
 		return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
 		return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
 	}
 	}
 
 
+	var checkTCPMetric func(float64) bool
+	if expectTCPPortForwardDial {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"peak_concurrent_dialing_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	if expectTCPDataTransfer {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_tcp",
+		"bytes_down_tcp",
+		"peak_concurrent_port_forward_count_tcp",
+		"total_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	var checkUDPMetric func(float64) bool
+	if expectUDPDataTransfer {
+		checkUDPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkUDPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_udp",
+		"bytes_down_udp",
+		"peak_concurrent_port_forward_count_udp",
+		"total_port_forward_count_udp",
+		"total_udpgw_channel_count",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkUDPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
 	return nil
 	return nil
 }
 }
 
 
@@ -1998,12 +2068,12 @@ func paveTrafficRulesFile(
 
 
 	allowTCPPorts := TCPPorts
 	allowTCPPorts := TCPPorts
 	allowUDPPorts := UDPPorts
 	allowUDPPorts := UDPPorts
-	disallowTCPPorts := "0"
-	disallowUDPPorts := "0"
+	disallowTCPPorts := "1"
+	disallowUDPPorts := "1"
 
 
 	if deny {
 	if deny {
-		allowTCPPorts = "0"
-		allowUDPPorts = "0"
+		allowTCPPorts = "1"
+		allowUDPPorts = "1"
 		disallowTCPPorts = TCPPorts
 		disallowTCPPorts = TCPPorts
 		disallowUDPPorts = UDPPorts
 		disallowUDPPorts = UDPPorts
 	}
 	}
@@ -2033,8 +2103,8 @@ func paveTrafficRulesFile(
                 "ReadUnthrottledBytes": %d,
                 "ReadUnthrottledBytes": %d,
                 "WriteUnthrottledBytes": %d
                 "WriteUnthrottledBytes": %d
             },
             },
-            "AllowTCPPorts" : [0],
-            "AllowUDPPorts" : [0],
+            "AllowTCPPorts" : [1],
+            "AllowUDPPorts" : [1],
             "MeekRateLimiterHistorySize" : 10,
             "MeekRateLimiterHistorySize" : 10,
             "MeekRateLimiterThresholdSeconds" : 1,
             "MeekRateLimiterThresholdSeconds" : 1,
             "MeekRateLimiterGarbageCollectionTriggerCount" : 1,
             "MeekRateLimiterGarbageCollectionTriggerCount" : 1,

+ 18 - 93
psiphon/server/trafficRules.go

@@ -236,21 +236,21 @@ type TrafficRules struct {
 
 
 	// AllowTCPPorts specifies a list of TCP ports that are permitted for port
 	// AllowTCPPorts specifies a list of TCP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowTCPPorts []int
+	AllowTCPPorts *common.PortList
 
 
 	// AllowUDPPorts specifies a list of UDP ports that are permitted for port
 	// AllowUDPPorts specifies a list of UDP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowUDPPorts []int
+	AllowUDPPorts *common.PortList
 
 
 	// DisallowTCPPorts specifies a list of TCP ports that are not permitted for
 	// DisallowTCPPorts specifies a list of TCP ports that are not permitted for
 	// port forwarding. DisallowTCPPorts takes priority over AllowTCPPorts and
 	// port forwarding. DisallowTCPPorts takes priority over AllowTCPPorts and
 	// AllowSubnets.
 	// AllowSubnets.
-	DisallowTCPPorts []int
+	DisallowTCPPorts *common.PortList
 
 
 	// DisallowUDPPorts specifies a list of UDP ports that are not permitted for
 	// DisallowUDPPorts specifies a list of UDP ports that are not permitted for
 	// port forwarding. DisallowUDPPorts takes priority over AllowUDPPorts and
 	// port forwarding. DisallowUDPPorts takes priority over AllowUDPPorts and
 	// AllowSubnets.
 	// AllowSubnets.
-	DisallowUDPPorts []int
+	DisallowUDPPorts *common.PortList
 
 
 	// AllowSubnets specifies a list of IP address subnets for which all TCP and
 	// AllowSubnets specifies a list of IP address subnets for which all TCP and
 	// UDP ports are allowed. This list is consulted if a port is disallowed by
 	// UDP ports are allowed. This list is consulted if a port is disallowed by
@@ -261,11 +261,6 @@ type TrafficRules struct {
 	// client sends an IP address. Domain names are not resolved before checking
 	// client sends an IP address. Domain names are not resolved before checking
 	// AllowSubnets.
 	// AllowSubnets.
 	AllowSubnets []string
 	AllowSubnets []string
-
-	allowTCPPortsLookup    map[int]bool
-	allowUDPPortsLookup    map[int]bool
-	disallowTCPPortsLookup map[int]bool
-	disallowUDPPortsLookup map[int]bool
 }
 }
 
 
 // RateLimits is a clone of common.RateLimits with pointers
 // RateLimits is a clone of common.RateLimits with pointers
@@ -434,33 +429,11 @@ func (set *TrafficRulesSet) initLookups() {
 
 
 	initTrafficRulesLookups := func(rules *TrafficRules) {
 	initTrafficRulesLookups := func(rules *TrafficRules) {
 
 
-		if len(rules.AllowTCPPorts) >= intLookupThreshold {
-			rules.allowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowTCPPorts {
-				rules.allowTCPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.AllowUDPPorts) >= intLookupThreshold {
-			rules.allowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowUDPPorts {
-				rules.allowUDPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.DisallowTCPPorts) >= intLookupThreshold {
-			rules.disallowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowTCPPorts {
-				rules.disallowTCPPortsLookup[port] = true
-			}
-		}
+		rules.AllowTCPPorts.OptimizeLookups()
+		rules.AllowUDPPorts.OptimizeLookups()
+		rules.DisallowTCPPorts.OptimizeLookups()
+		rules.DisallowUDPPorts.OptimizeLookups()
 
 
-		if len(rules.DisallowUDPPorts) >= intLookupThreshold {
-			rules.disallowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowUDPPorts {
-				rules.disallowUDPPortsLookup[port] = true
-			}
-		}
 	}
 	}
 
 
 	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
 	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
@@ -600,14 +573,6 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_MAX_UDP_PORT_FORWARD_COUNT)
 			intPtr(DEFAULT_MAX_UDP_PORT_FORWARD_COUNT)
 	}
 	}
 
 
-	if trafficRules.AllowTCPPorts == nil {
-		trafficRules.AllowTCPPorts = make([]int, 0)
-	}
-
-	if trafficRules.AllowUDPPorts == nil {
-		trafficRules.AllowUDPPorts = make([]int, 0)
-	}
-
 	if trafficRules.AllowSubnets == nil {
 	if trafficRules.AllowSubnets == nil {
 		trafficRules.AllowSubnets = make([]string, 0)
 		trafficRules.AllowSubnets = make([]string, 0)
 	}
 	}
@@ -800,22 +765,18 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 
 		if filteredRules.Rules.AllowTCPPorts != nil {
 		if filteredRules.Rules.AllowTCPPorts != nil {
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
-			trafficRules.allowTCPPortsLookup = filteredRules.Rules.allowTCPPortsLookup
 		}
 		}
 
 
 		if filteredRules.Rules.AllowUDPPorts != nil {
 		if filteredRules.Rules.AllowUDPPorts != nil {
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
-			trafficRules.allowUDPPortsLookup = filteredRules.Rules.allowUDPPortsLookup
 		}
 		}
 
 
 		if filteredRules.Rules.DisallowTCPPorts != nil {
 		if filteredRules.Rules.DisallowTCPPorts != nil {
 			trafficRules.DisallowTCPPorts = filteredRules.Rules.DisallowTCPPorts
 			trafficRules.DisallowTCPPorts = filteredRules.Rules.DisallowTCPPorts
-			trafficRules.disallowTCPPortsLookup = filteredRules.Rules.disallowTCPPortsLookup
 		}
 		}
 
 
 		if filteredRules.Rules.DisallowUDPPorts != nil {
 		if filteredRules.Rules.DisallowUDPPorts != nil {
 			trafficRules.DisallowUDPPorts = filteredRules.Rules.DisallowUDPPorts
 			trafficRules.DisallowUDPPorts = filteredRules.Rules.DisallowUDPPorts
-			trafficRules.disallowUDPPortsLookup = filteredRules.Rules.disallowUDPPortsLookup
 		}
 		}
 
 
 		if filteredRules.Rules.AllowSubnets != nil {
 		if filteredRules.Rules.AllowSubnets != nil {
@@ -837,34 +798,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 
 func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
 
-	if len(rules.DisallowTCPPorts) > 0 {
-		if rules.disallowTCPPortsLookup != nil {
-			if rules.disallowTCPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowTCPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowTCPPorts.Lookup(port) {
+		return false
 	}
 	}
 
 
-	if len(rules.AllowTCPPorts) == 0 {
+	if rules.AllowTCPPorts.IsEmpty() {
 		return true
 		return true
 	}
 	}
 
 
-	if rules.allowTCPPortsLookup != nil {
-		if rules.allowTCPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowTCPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowTCPPorts.Lookup(port) {
+		return true
 	}
 	}
 
 
 	return rules.allowSubnet(remoteIP)
 	return rules.allowSubnet(remoteIP)
@@ -872,34 +815,16 @@ func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
 
 func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
 func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
 
 
-	if len(rules.DisallowUDPPorts) > 0 {
-		if rules.disallowUDPPortsLookup != nil {
-			if rules.disallowUDPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowUDPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowUDPPorts.Lookup(port) {
+		return false
 	}
 	}
 
 
-	if len(rules.AllowUDPPorts) == 0 {
+	if rules.AllowUDPPorts.IsEmpty() {
 		return true
 		return true
 	}
 	}
 
 
-	if rules.allowUDPPortsLookup != nil {
-		if rules.allowUDPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowUDPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowUDPPorts.Lookup(port) {
+		return true
 	}
 	}
 
 
 	return rules.allowSubnet(remoteIP)
 	return rules.allowSubnet(remoteIP)

+ 40 - 18
psiphon/server/tunnelServer.go

@@ -1260,8 +1260,10 @@ type sshClient struct {
 	isFirstTunnelInSession               bool
 	isFirstTunnelInSession               bool
 	supportsServerRequests               bool
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
 	handshakeState                       handshakeState
-	udpChannel                           ssh.Channel
+	udpgwChannelHandler                  *udpgwPortForwardMultiplexer
+	totalUdpgwChannelCount               int
 	packetTunnelChannel                  ssh.Channel
 	packetTunnelChannel                  ssh.Channel
+	totalPacketTunnelChannelCount        int
 	trafficRules                         TrafficRules
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	tcpTrafficState                      trafficState
 	udpTrafficState                      trafficState
 	udpTrafficState                      trafficState
@@ -2495,11 +2497,11 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 
 
 	// Intercept TCP port forwards to a specified udpgw server and handle directly.
 	// Intercept TCP port forwards to a specified udpgw server and handle directly.
 	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
 	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
-	isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+	isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
 		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
 		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
 			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
 			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
 
 
-	if isUDPChannel {
+	if isUdpgwChannel {
 
 
 		// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
 		// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
 		// own worker goroutine.
 		// own worker goroutine.
@@ -2507,7 +2509,7 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 		waitGroup.Add(1)
 		waitGroup.Add(1)
 		go func(channel ssh.NewChannel) {
 		go func(channel ssh.NewChannel) {
 			defer waitGroup.Done()
 			defer waitGroup.Done()
-			sshClient.handleUDPChannel(channel)
+			sshClient.handleUdpgwChannel(channel)
 		}(newChannel)
 		}(newChannel)
 
 
 	} else {
 	} else {
@@ -2558,20 +2560,39 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
 		sshClient.packetTunnelChannel.Close()
 		sshClient.packetTunnelChannel.Close()
 	}
 	}
 	sshClient.packetTunnelChannel = channel
 	sshClient.packetTunnelChannel = channel
+	sshClient.totalPacketTunnelChannelCount += 1
 	sshClient.Unlock()
 	sshClient.Unlock()
 }
 }
 
 
-// setUDPChannel sets the single UDP channel for this sshClient.
-// Each sshClient may have only one concurrent UDP channel. Each
-// UDP channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing UDP channel is closed.
-func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+// setUdpgwChannelHandler sets the single udpgw channel handler for this
+// sshClient. Each sshClient may have only one concurrent udpgw
+// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
+// the udpgw protocol. Any existing udpgw channel/handler is closed.
+func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool {
 	sshClient.Lock()
 	sshClient.Lock()
-	if sshClient.udpChannel != nil {
-		sshClient.udpChannel.Close()
+	if sshClient.udpgwChannelHandler != nil {
+		previousHandler := sshClient.udpgwChannelHandler
+		sshClient.udpgwChannelHandler = nil
+
+		// stop must be run without holding the sshClient mutex lock, as the
+		// udpgw goroutines may attempt to lock the same mutex. For example,
+		// udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward
+		// which calls sshClient.allocatePortForward.
+		sshClient.Unlock()
+		previousHandler.stop()
+		sshClient.Lock()
+
+		// In case some other channel has set the sshClient.udpgwChannelHandler
+		// in the meantime, fail. The caller should discard this channel/handler.
+		if sshClient.udpgwChannelHandler != nil {
+			sshClient.Unlock()
+			return false
+		}
 	}
 	}
-	sshClient.udpChannel = channel
+	sshClient.udpgwChannelHandler = udpgwChannelHandler
+	sshClient.totalUdpgwChannelCount += 1
 	sshClient.Unlock()
 	sshClient.Unlock()
+	return true
 }
 }
 
 
 var serverTunnelStatParams = append(
 var serverTunnelStatParams = append(
@@ -2616,6 +2637,8 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
+	logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount
+	logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount
 
 
 	logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
 	logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
 	logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
 	logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
@@ -2645,12 +2668,11 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 		}
 		}
 	}
 	}
 
 
-	sshClient.Unlock()
-
-	// Note: unlock before use is only safe as long as referenced sshClient data,
-	// such as slices in handshakeState, is read-only after initially set.
-
+	// Retain lock when invoking LogRawFieldsWithTimestamp to block any
+	// concurrent writes to variables referenced by logFields.
 	log.LogRawFieldsWithTimestamp(logFields)
 	log.LogRawFieldsWithTimestamp(logFields)
+
+	sshClient.Unlock()
 }
 }
 
 
 var blocklistHitsStatParams = []requestParamSpec{
 var blocklistHitsStatParams = []requestParamSpec{
@@ -2827,7 +2849,7 @@ func (sshClient *sshClient) enqueueDisallowedTrafficAlertRequest() {
 
 
 	sshClient.enqueueAlertRequest(
 	sshClient.enqueueAlertRequest(
 		protocol.AlertRequest{
 		protocol.AlertRequest{
-			Reason:     protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC,
+			Reason:     reason,
 			ActionURLs: actionURLs,
 			ActionURLs: actionURLs,
 		})
 		})
 }
 }

+ 105 - 54
psiphon/server/udp.go

@@ -25,7 +25,6 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
-	"runtime/debug"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
@@ -35,7 +34,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 )
 )
 
 
-// handleUDPChannel implements UDP port forwarding. A single UDP
+// handleUdpgwChannel implements UDP port forwarding. A single UDP
 // SSH channel follows the udpgw protocol, which multiplexes many
 // SSH channel follows the udpgw protocol, which multiplexes many
 // UDP port forwards.
 // UDP port forwards.
 //
 //
@@ -43,10 +42,10 @@ import (
 // Copyright (c) 2009, Ambroz Bizjak <[email protected]>
 // Copyright (c) 2009, Ambroz Bizjak <[email protected]>
 // https://github.com/ambrop72/badvpn
 // https://github.com/ambrop72/badvpn
 //
 //
-func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 
 
 	// Accept this channel immediately. This channel will replace any
 	// Accept this channel immediately. This channel will replace any
-	// previously existing UDP channel for this client.
+	// previously existing udpgw channel for this client.
 
 
 	sshChannel, requests, err := newChannel.Accept()
 	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
 	if err != nil {
@@ -58,33 +57,81 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	go ssh.DiscardRequests(requests)
 	go ssh.DiscardRequests(requests)
 	defer sshChannel.Close()
 	defer sshChannel.Close()
 
 
-	sshClient.setUDPChannel(sshChannel)
-
-	multiplexer := &udpPortForwardMultiplexer{
+	multiplexer := &udpgwPortForwardMultiplexer{
 		sshClient:      sshClient,
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
 		sshChannel:     sshChannel,
-		portForwards:   make(map[uint16]*udpPortForward),
+		portForwards:   make(map[uint16]*udpgwPortForward),
 		portForwardLRU: common.NewLRUConns(),
 		portForwardLRU: common.NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
 		relayWaitGroup: new(sync.WaitGroup),
+		runWaitGroup:   new(sync.WaitGroup),
+	}
+
+	multiplexer.runWaitGroup.Add(1)
+
+	// setUdpgwChannelHandler will close any existing
+	// udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
+	// goroutines to first terminate and all UDP socket resources to be
+	// cleaned up.
+	//
+	// This synchronous shutdown also ensures that the
+	// concurrentPortForwardCount is reduced to 0 before installing the new
+	// udpgwPortForwardMultiplexer and its LRU object. If the older handler
+	// were to dangle with open port forwards, and concurrentPortForwardCount
+	// were to hit the max, the wrong LRU, the new one, would be used to
+	// close the LRU port forward.
+	//
+	// Call setUdpgwHandler only after runWaitGroup is initialized, to ensure
+	// runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
+	// channel) before initialized.
+
+	if !sshClient.setUdpgwChannelHandler(multiplexer) {
+		// setUdpgwChannelHandler returns false if some other SSH channel
+		// calls setUdpgwChannelHandler in the middle of this call. In that
+		// case, discard this channel: the client's latest udpgw channel is
+		// retained.
+		return
 	}
 	}
+
 	multiplexer.run()
 	multiplexer.run()
+	multiplexer.runWaitGroup.Done()
 }
 }
 
 
-type udpPortForwardMultiplexer struct {
+type udpgwPortForwardMultiplexer struct {
 	sshClient            *sshClient
 	sshClient            *sshClient
 	sshChannelWriteMutex sync.Mutex
 	sshChannelWriteMutex sync.Mutex
 	sshChannel           ssh.Channel
 	sshChannel           ssh.Channel
 	portForwardsMutex    sync.Mutex
 	portForwardsMutex    sync.Mutex
-	portForwards         map[uint16]*udpPortForward
+	portForwards         map[uint16]*udpgwPortForward
 	portForwardLRU       *common.LRUConns
 	portForwardLRU       *common.LRUConns
 	relayWaitGroup       *sync.WaitGroup
 	relayWaitGroup       *sync.WaitGroup
+	runWaitGroup         *sync.WaitGroup
+}
+
+func (mux *udpgwPortForwardMultiplexer) stop() {
+
+	// udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
+	//
+	// stop closes the udpgw SSH channel, which will cause the run goroutine
+	// to exit its message read loop and await closure of all relayDownstream
+	// goroutines. Closing all port forward UDP conns will cause all
+	// relayDownstream to exit.
+
+	_ = mux.sshChannel.Close()
+
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
+		_ = portForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+
+	mux.runWaitGroup.Wait()
 }
 }
 
 
-func (mux *udpPortForwardMultiplexer) run() {
+func (mux *udpgwPortForwardMultiplexer) run() {
 
 
-	// In a loop, read udpgw messages from the client to this channel. Each message is
-	// a UDP packet to send upstream either via a new port forward, or on an existing
-	// port forward.
+	// In a loop, read udpgw messages from the client to this channel. Each
+	// message contains a UDP packet to send upstream either via a new port
+	// forward, or on an existing port forward.
 	//
 	//
 	// A goroutine is run to read downstream packets for each UDP port forward. All read
 	// A goroutine is run to read downstream packets for each UDP port forward. All read
 	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
 	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
@@ -92,16 +139,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 	// When the client disconnects or the server shuts down, the channel will close and
 	// When the client disconnects or the server shuts down, the channel will close and
 	// readUdpgwMessage will exit with EOF.
 	// readUdpgwMessage will exit with EOF.
 
 
-	// Recover from and log any unexpected panics caused by udpgw input handling bugs.
-	// Note: this covers the run() goroutine only and not relayDownstream() goroutines.
-	defer func() {
-		if e := recover(); e != nil {
-			err := errors.Tracef(
-				"udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())
-			log.WithTraceFields(LogFields{"error": err}).Warning("run failed")
-		}
-	}()
-
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 	for {
 	for {
 		// Note: message.packet points to the reusable memory in "buffer".
 		// Note: message.packet points to the reusable memory in "buffer".
@@ -119,27 +156,37 @@ func (mux *udpPortForwardMultiplexer) run() {
 		portForward := mux.portForwards[message.connID]
 		portForward := mux.portForwards[message.connID]
 		mux.portForwardsMutex.Unlock()
 		mux.portForwardsMutex.Unlock()
 
 
-		if portForward != nil && message.discardExistingConn {
+		// In the udpgw protocol, an existing port forward is closed when
+		// either the discard flag is set or the remote address has changed.
+
+		if portForward != nil &&
+			(message.discardExistingConn ||
+				!bytes.Equal(portForward.remoteIP, message.remoteIP) ||
+				portForward.remotePort != message.remotePort) {
+
 			// The port forward's goroutine will complete cleanup, including
 			// The port forward's goroutine will complete cleanup, including
 			// tallying stats and calling sshClient.closedPortForward.
 			// tallying stats and calling sshClient.closedPortForward.
 			// portForward.conn.Close() will signal this shutdown.
 			// portForward.conn.Close() will signal this shutdown.
-			// TODO: wait for goroutine to exit before proceeding?
 			portForward.conn.Close()
 			portForward.conn.Close()
-			portForward = nil
-		}
-
-		if portForward != nil {
 
 
-			// Verify that portForward remote address matches latest message
+			// Synchronously await the termination of the relayDownstream
+			// goroutine. This ensures that the previous goroutine won't
+			// invoke removePortForward, with the connID that will be reused
+			// for the new port forward, after this point.
+			//
+			// Limitation: this synchronous shutdown cannot prevent a "wrong
+			// remote address" error on the badvpn udpgw client, which occurs
+			// when the client recycles a port forward (setting discard) but
+			// receives, from the server, a udpgw message containing the old
+			// remote address for the previous port forward with the same
+			// conn ID. That downstream message from the server may be in
+			// flight in the SSH channel when the client discard message arrives.
+			portForward.relayWaitGroup.Wait()
 
 
-			if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
-				portForward.remotePort != message.remotePort {
-
-				log.WithTrace().Warning("UDP port forward remote address mismatch")
-				continue
-			}
+			portForward = nil
+		}
 
 
-		} else {
+		if portForward == nil {
 
 
 			// Create a new port forward
 			// Create a new port forward
 
 
@@ -237,17 +284,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 				continue
 			}
 			}
 
 
-			portForward = &udpPortForward{
-				connID:       message.connID,
-				preambleSize: message.preambleSize,
-				remoteIP:     message.remoteIP,
-				remotePort:   message.remotePort,
-				dialIP:       dialIP,
-				conn:         conn,
-				lruEntry:     lruEntry,
-				bytesUp:      0,
-				bytesDown:    0,
-				mux:          mux,
+			portForward = &udpgwPortForward{
+				connID:         message.connID,
+				preambleSize:   message.preambleSize,
+				remoteIP:       message.remoteIP,
+				remotePort:     message.remotePort,
+				dialIP:         dialIP,
+				conn:           conn,
+				lruEntry:       lruEntry,
+				bytesUp:        0,
+				bytesDown:      0,
+				relayWaitGroup: new(sync.WaitGroup),
+				mux:            mux,
 			}
 			}
 
 
 			if message.forwardDNS {
 			if message.forwardDNS {
@@ -258,6 +306,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 			mux.portForwardsMutex.Unlock()
 
 
+			portForward.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 			go portForward.relayDownstream()
 		}
 		}
@@ -276,7 +325,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 	}
 
 
-	// Cleanup all UDP port forward workers when exiting
+	// Cleanup all udpgw port forward workers when exiting
 
 
 	mux.portForwardsMutex.Lock()
 	mux.portForwardsMutex.Lock()
 	for _, portForward := range mux.portForwards {
 	for _, portForward := range mux.portForwards {
@@ -288,13 +337,13 @@ func (mux *udpPortForwardMultiplexer) run() {
 	mux.relayWaitGroup.Wait()
 	mux.relayWaitGroup.Wait()
 }
 }
 
 
-func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) {
 	mux.portForwardsMutex.Lock()
 	mux.portForwardsMutex.Lock()
 	delete(mux.portForwards, connID)
 	delete(mux.portForwards, connID)
 	mux.portForwardsMutex.Unlock()
 	mux.portForwardsMutex.Unlock()
 }
 }
 
 
-type udpPortForward struct {
+type udpgwPortForward struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
@@ -309,10 +358,12 @@ type udpPortForward struct {
 	dialIP            net.IP
 	dialIP            net.IP
 	conn              net.Conn
 	conn              net.Conn
 	lruEntry          *common.LRUConnsEntry
 	lruEntry          *common.LRUConnsEntry
-	mux               *udpPortForwardMultiplexer
+	relayWaitGroup    *sync.WaitGroup
+	mux               *udpgwPortForwardMultiplexer
 }
 }
 
 
-func (portForward *udpPortForward) relayDownstream() {
+func (portForward *udpgwPortForward) relayDownstream() {
+	defer portForward.relayWaitGroup.Done()
 	defer portForward.mux.relayWaitGroup.Done()
 	defer portForward.mux.relayWaitGroup.Done()
 
 
 	// Downstream UDP packets are read into the reusable memory
 	// Downstream UDP packets are read into the reusable memory

+ 5 - 0
psiphon/upstreamproxy/go-ntlm/README.md

@@ -1,3 +1,8 @@
+
+This is a fork of the package that previously existed at https://github.com/ThomsonReutersEikon/go-ntlm/ntlm. It contains several bug fixes, tagged with "[Psiphon]".
+
+Original github.com/ThomsonReutersEikon/go-ntlm/ntlm README:
+
 # NTLM Implementation for Go
 # NTLM Implementation for Go
 
 
 This is a native implementation of NTLM for Go that was implemented using the Microsoft MS-NLMP documentation available at http://msdn.microsoft.com/en-us/library/cc236621.aspx.
 This is a native implementation of NTLM for Go that was implemented using the Microsoft MS-NLMP documentation available at http://msdn.microsoft.com/en-us/library/cc236621.aspx.

+ 23 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/av_pairs.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"bytes"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
+	"errors"
 	"fmt"
 	"fmt"
 )
 )
 
 
@@ -51,13 +52,16 @@ func (p *AvPairs) AddAvPair(avId AvPairType, bytes []byte) {
 	p.List = append(p.List, *a)
 	p.List = append(p.List, *a)
 }
 }
 
 
-func ReadAvPairs(data []byte) *AvPairs {
+func ReadAvPairs(data []byte) (*AvPairs, error) {
 	pairs := new(AvPairs)
 	pairs := new(AvPairs)
 
 
 	// Get the number of AvPairs and allocate enough AvPair structures to hold them
 	// Get the number of AvPairs and allocate enough AvPair structures to hold them
 	offset := 0
 	offset := 0
 	for i := 0; len(data) > 0 && i < 11; i++ {
 	for i := 0; len(data) > 0 && i < 11; i++ {
-		pair := ReadAvPair(data, offset)
+		pair, err := ReadAvPair(data, offset)
+		if err != nil {
+			return nil, err
+		}
 		offset = offset + 4 + int(pair.AvLen)
 		offset = offset + 4 + int(pair.AvLen)
 		pairs.List = append(pairs.List, *pair)
 		pairs.List = append(pairs.List, *pair)
 		if pair.AvId == MsvAvEOL {
 		if pair.AvId == MsvAvEOL {
@@ -65,7 +69,7 @@ func ReadAvPairs(data []byte) *AvPairs {
 		}
 		}
 	}
 	}
 
 
-	return pairs
+	return pairs, nil
 }
 }
 
 
 func (p *AvPairs) Bytes() (result []byte) {
 func (p *AvPairs) Bytes() (result []byte) {
@@ -131,12 +135,26 @@ type AvPair struct {
 	Value []byte
 	Value []byte
 }
 }
 
 
-func ReadAvPair(data []byte, offset int) *AvPair {
+func ReadAvPair(data []byte, offset int) (*AvPair, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(data) < offset+4 {
+		return nil, errors.New("invalid AvPair")
+	}
+
 	pair := new(AvPair)
 	pair := new(AvPair)
 	pair.AvId = AvPairType(binary.LittleEndian.Uint16(data[offset : offset+2]))
 	pair.AvId = AvPairType(binary.LittleEndian.Uint16(data[offset : offset+2]))
 	pair.AvLen = binary.LittleEndian.Uint16(data[offset+2 : offset+4])
 	pair.AvLen = binary.LittleEndian.Uint16(data[offset+2 : offset+4])
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(data) < offset+4+int(pair.AvLen) {
+		return nil, errors.New("invalid AvPair")
+	}
+
 	pair.Value = data[offset+4 : offset+4+int(pair.AvLen)]
 	pair.Value = data[offset+4 : offset+4+int(pair.AvLen)]
-	return pair
+	return pair, nil
 }
 }
 
 
 func (a *AvPair) UnicodeStringValue() string {
 func (a *AvPair) UnicodeStringValue() string {

+ 37 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/challenge_responses.go

@@ -21,6 +21,13 @@ func (n *NtlmV1Response) String() string {
 }
 }
 
 
 func ReadNtlmV1Response(bytes []byte) (*NtlmV1Response, error) {
 func ReadNtlmV1Response(bytes []byte) (*NtlmV1Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid NTLM v1 response")
+	}
+
 	r := new(NtlmV1Response)
 	r := new(NtlmV1Response)
 	r.Response = bytes[0:24]
 	r.Response = bytes[0:24]
 	return r, nil
 	return r, nil
@@ -84,6 +91,13 @@ func (n *NtlmV2Response) String() string {
 }
 }
 
 
 func ReadNtlmV2Response(bytes []byte) (*NtlmV2Response, error) {
 func ReadNtlmV2Response(bytes []byte) (*NtlmV2Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 45 {
+		return nil, errors.New("invalid NTLM v2 response")
+	}
+
 	r := new(NtlmV2Response)
 	r := new(NtlmV2Response)
 	r.Response = bytes[0:16]
 	r.Response = bytes[0:16]
 	r.NtlmV2ClientChallenge = new(NtlmV2ClientChallenge)
 	r.NtlmV2ClientChallenge = new(NtlmV2ClientChallenge)
@@ -103,7 +117,11 @@ func ReadNtlmV2Response(bytes []byte) (*NtlmV2Response, error) {
 	c.ChallengeFromClient = bytes[32:40]
 	c.ChallengeFromClient = bytes[32:40]
 	// Ignoring - 4 bytes reserved
 	// Ignoring - 4 bytes reserved
 	// c.Reserved3
 	// c.Reserved3
-	c.AvPairs = ReadAvPairs(bytes[44:])
+	var err error
+	c.AvPairs, err = ReadAvPairs(bytes[44:])
+	if err != nil {
+		return nil, err
+	}
 	return r, nil
 	return r, nil
 }
 }
 
 
@@ -114,10 +132,17 @@ type LmV1Response struct {
 	Response []byte
 	Response []byte
 }
 }
 
 
-func ReadLmV1Response(bytes []byte) *LmV1Response {
+func ReadLmV1Response(bytes []byte) (*LmV1Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid LM v1 response")
+	}
+
 	r := new(LmV1Response)
 	r := new(LmV1Response)
 	r.Response = bytes[0:24]
 	r.Response = bytes[0:24]
-	return r
+	return r, nil
 }
 }
 
 
 func (l *LmV1Response) String() string {
 func (l *LmV1Response) String() string {
@@ -136,11 +161,18 @@ type LmV2Response struct {
 	ChallengeFromClient []byte
 	ChallengeFromClient []byte
 }
 }
 
 
-func ReadLmV2Response(bytes []byte) *LmV2Response {
+func ReadLmV2Response(bytes []byte) (*LmV2Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid LM v2 response")
+	}
+
 	r := new(LmV2Response)
 	r := new(LmV2Response)
 	r.Response = bytes[0:16]
 	r.Response = bytes[0:16]
 	r.ChallengeFromClient = bytes[16:24]
 	r.ChallengeFromClient = bytes[16:24]
-	return r
+	return r, nil
 }
 }
 
 
 func (l *LmV2Response) String() string {
 func (l *LmV2Response) String() string {

+ 38 - 4
psiphon/upstreamproxy/go-ntlm/ntlm/message_authenticate.go

@@ -38,7 +38,7 @@ type AuthenticateMessage struct {
 	/// MS-NLMP 2.2.1.3 - In connectionless mode, a NEGOTIATE structure that contains a set of bit flags (section 2.2.2.5) and represents the
 	/// MS-NLMP 2.2.1.3 - In connectionless mode, a NEGOTIATE structure that contains a set of bit flags (section 2.2.2.5) and represents the
 	// conclusion of negotiation—the choices the client has made from the options the server offered in the CHALLENGE_MESSAGE.
 	// conclusion of negotiation—the choices the client has made from the options the server offered in the CHALLENGE_MESSAGE.
 	// In connection-oriented mode, a NEGOTIATE structure that contains the set of bit flags (section 2.2.2.5) negotiated in
 	// In connection-oriented mode, a NEGOTIATE structure that contains the set of bit flags (section 2.2.2.5) negotiated in
-	// the previous 
+	// the previous
 	NegotiateFlags uint32 // 4 bytes
 	NegotiateFlags uint32 // 4 bytes
 
 
 	// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION
 	// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION
@@ -56,6 +56,12 @@ type AuthenticateMessage struct {
 func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessage, error) {
 func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessage, error) {
 	am := new(AuthenticateMessage)
 	am := new(AuthenticateMessage)
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < 12 {
+		return nil, errors.New("invalid authenticate message")
+	}
+
 	am.Signature = body[0:8]
 	am.Signature = body[0:8]
 	if !bytes.Equal(am.Signature, []byte("NTLMSSP\x00")) {
 	if !bytes.Equal(am.Signature, []byte("NTLMSSP\x00")) {
 		return nil, errors.New("Invalid NTLM message signature")
 		return nil, errors.New("Invalid NTLM message signature")
@@ -74,9 +80,12 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 	}
 	}
 
 
 	if ntlmVersion == 2 {
 	if ntlmVersion == 2 {
-		am.LmV2Response = ReadLmV2Response(am.LmChallengeResponse.Payload)
+		am.LmV2Response, err = ReadLmV2Response(am.LmChallengeResponse.Payload)
 	} else {
 	} else {
-		am.LmV1Response = ReadLmV1Response(am.LmChallengeResponse.Payload)
+		am.LmV1Response, err = ReadLmV1Response(am.LmChallengeResponse.Payload)
+	}
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
 	am.NtChallengeResponseFields, err = ReadBytePayload(20, body)
 	am.NtChallengeResponseFields, err = ReadBytePayload(20, body)
@@ -90,7 +99,6 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 	} else {
 	} else {
 		am.NtlmV1Response, err = ReadNtlmV1Response(am.NtChallengeResponseFields.Payload)
 		am.NtlmV1Response, err = ReadNtlmV1Response(am.NtChallengeResponseFields.Payload)
 	}
 	}
-
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -124,11 +132,24 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 		}
 		}
 		offset = offset + 8
 		offset = offset + 8
 
 
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(body) < offset+4 {
+			return nil, errors.New("invalid authenticate message")
+		}
+
 		am.NegotiateFlags = binary.LittleEndian.Uint32(body[offset : offset+4])
 		am.NegotiateFlags = binary.LittleEndian.Uint32(body[offset : offset+4])
 		offset = offset + 4
 		offset = offset + 4
 
 
 		// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION flag is set in the NegotiateFlags field. This structure is used for debugging purposes only. In normal protocol messages, it is ignored and does not affect the NTLM message processing.<9>
 		// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION flag is set in the NegotiateFlags field. This structure is used for debugging purposes only. In normal protocol messages, it is ignored and does not affect the NTLM message processing.<9>
 		if NTLMSSP_NEGOTIATE_VERSION.IsSet(am.NegotiateFlags) {
 		if NTLMSSP_NEGOTIATE_VERSION.IsSet(am.NegotiateFlags) {
+
+			// [Psiphon]
+			// Don't panic on malformed remote input.
+			if len(body) < offset+8 {
+				return nil, errors.New("invalid authenticate message")
+			}
+
 			am.Version, err = ReadVersionStruct(body[offset : offset+8])
 			am.Version, err = ReadVersionStruct(body[offset : offset+8])
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
@@ -144,12 +165,25 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 		// there is a MIC and read it out.
 		// there is a MIC and read it out.
 		var lowestOffset = am.getLowestPayloadOffset()
 		var lowestOffset = am.getLowestPayloadOffset()
 		if lowestOffset > offset {
 		if lowestOffset > offset {
+
+			// [Psiphon]
+			// Don't panic on malformed remote input.
+			if len(body) < offset+16 {
+				return nil, errors.New("invalid authenticate message")
+			}
+
 			// MIC - 16 bytes
 			// MIC - 16 bytes
 			am.Mic = body[offset : offset+16]
 			am.Mic = body[offset : offset+16]
 			offset = offset + 16
 			offset = offset + 16
 		}
 		}
 	}
 	}
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < offset {
+		return nil, errors.New("invalid authenticate message")
+	}
+
 	am.Payload = body[offset:]
 	am.Payload = body[offset:]
 
 
 	return am, nil
 	return am, nil

+ 23 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/message_challenge.go

@@ -56,6 +56,12 @@ type ChallengeMessage struct {
 func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 	challenge := new(ChallengeMessage)
 	challenge := new(ChallengeMessage)
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < 40 {
+		return nil, errors.New("invalid challenge message")
+	}
+
 	challenge.Signature = body[0:8]
 	challenge.Signature = body[0:8]
 	if !bytes.Equal(challenge.Signature, []byte("NTLMSSP\x00")) {
 	if !bytes.Equal(challenge.Signature, []byte("NTLMSSP\x00")) {
 		return challenge, errors.New("Invalid NTLM message signature")
 		return challenge, errors.New("Invalid NTLM message signature")
@@ -84,11 +90,21 @@ func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	challenge.TargetInfo = ReadAvPairs(challenge.TargetInfoPayloadStruct.Payload)
+	challenge.TargetInfo, err = ReadAvPairs(challenge.TargetInfoPayloadStruct.Payload)
+	if err != nil {
+		return nil, err
+	}
 
 
 	offset := 48
 	offset := 48
 
 
 	if NTLMSSP_NEGOTIATE_VERSION.IsSet(challenge.NegotiateFlags) {
 	if NTLMSSP_NEGOTIATE_VERSION.IsSet(challenge.NegotiateFlags) {
+
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(body) < offset+8 {
+			return nil, errors.New("invalid challenge message")
+		}
+
 		challenge.Version, err = ReadVersionStruct(body[offset : offset+8])
 		challenge.Version, err = ReadVersionStruct(body[offset : offset+8])
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -96,6 +112,12 @@ func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 		offset = offset + 8
 		offset = offset + 8
 	}
 	}
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < offset {
+		return nil, errors.New("invalid challenge message")
+	}
+
 	challenge.Payload = body[offset:]
 	challenge.Payload = body[offset:]
 
 
 	return challenge, nil
 	return challenge, nil

+ 6 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1.go

@@ -343,7 +343,12 @@ func (n *V1ClientSession) GenerateAuthenticateMessage() (am *AuthenticateMessage
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.UserName, _ = CreateStringPayload(n.user)
 	am.UserName, _ = CreateStringPayload(n.user)
-	am.Workstation, _ = CreateStringPayload("SQUAREMILL")
+
+	// [Psiphon]
+	// Set a blank workstation value, which is less distinguishable than the previous hard-coded value.
+	// See also: https://github.com/Azure/go-ntlmssp/commit/5e29b886690f00c76b876ae9ab8e31e7c3509203.
+
+	am.Workstation, _ = CreateStringPayload("")
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.NegotiateFlags = n.NegotiateFlags
 	am.NegotiateFlags = n.NegotiateFlags
 	am.Version = &VersionStruct{ProductMajorVersion: uint8(5), ProductMinorVersion: uint8(1), ProductBuild: uint16(2600), NTLMRevisionCurrent: uint8(15)}
 	am.Version = &VersionStruct{ProductMajorVersion: uint8(5), ProductMinorVersion: uint8(1), ProductBuild: uint16(2600), NTLMRevisionCurrent: uint8(15)}

+ 5 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1_test.go

@@ -42,14 +42,14 @@ func checkV1Value(t *testing.T, name string, value []byte, expected string, err
 // would authenticate. This was due to a bug in the MS-NLMP docs. This tests for that issue
 // would authenticate. This was due to a bug in the MS-NLMP docs. This tests for that issue
 func TestNtlmV1ExtendedSessionSecurity(t *testing.T) {
 func TestNtlmV1ExtendedSessionSecurity(t *testing.T) {
 	// NTLMv1 with extended session security
 	// NTLMv1 with extended session security
-  challengeMessage := "TlRMTVNTUAACAAAAAAAAADgAAABVgphiRy3oSZvn1I4AAAAAAAAAAKIAogA4AAAABQEoCgAAAA8CAA4AUgBFAFUAVABFAFIAUwABABwAVQBLAEIAUAAtAEMAQgBUAFIATQBGAEUAMAA2AAQAFgBSAGUAdQB0AGUAcgBzAC4AbgBlAHQAAwA0AHUAawBiAHAALQBjAGIAdAByAG0AZgBlADAANgAuAFIAZQB1AHQAZQByAHMALgBuAGUAdAAFABYAUgBlAHUAdABlAHIAcwAuAG4AZQB0AAAAAAA="
-  authenticateMessage := "TlRMTVNTUAADAAAAGAAYAJgAAAAYABgAsAAAAAAAAABIAAAAOgA6AEgAAAAWABYAggAAABAAEADIAAAAVYKYYgUCzg4AAAAPMQAwADAAMAAwADEALgB3AGMAcABAAHQAaABvAG0AcwBvAG4AcgBlAHUAdABlAHIAcwAuAGMAbwBtAE4AWQBDAFMATQBTAEcAOQA5ADAAOQBRWAK3h/TIywAAAAAAAAAAAAAAAAAAAAA3tp89kZU1hs1XZp7KTyGm3XsFAT9stEDW9YXDaeYVBmBcBb//2FOu"
+	challengeMessage := "TlRMTVNTUAACAAAAAAAAADgAAABVgphiRy3oSZvn1I4AAAAAAAAAAKIAogA4AAAABQEoCgAAAA8CAA4AUgBFAFUAVABFAFIAUwABABwAVQBLAEIAUAAtAEMAQgBUAFIATQBGAEUAMAA2AAQAFgBSAGUAdQB0AGUAcgBzAC4AbgBlAHQAAwA0AHUAawBiAHAALQBjAGIAdAByAG0AZgBlADAANgAuAFIAZQB1AHQAZQByAHMALgBuAGUAdAAFABYAUgBlAHUAdABlAHIAcwAuAG4AZQB0AAAAAAA="
+	authenticateMessage := "TlRMTVNTUAADAAAAGAAYAJgAAAAYABgAsAAAAAAAAABIAAAAOgA6AEgAAAAWABYAggAAABAAEADIAAAAVYKYYgUCzg4AAAAPMQAwADAAMAAwADEALgB3AGMAcABAAHQAaABvAG0AcwBvAG4AcgBlAHUAdABlAHIAcwAuAGMAbwBtAE4AWQBDAFMATQBTAEcAOQA5ADAAOQBRWAK3h/TIywAAAAAAAAAAAAAAAAAAAAA3tp89kZU1hs1XZp7KTyGm3XsFAT9stEDW9YXDaeYVBmBcBb//2FOu"
 
 
 	challengeData, _ := base64.StdEncoding.DecodeString(challengeMessage)
 	challengeData, _ := base64.StdEncoding.DecodeString(challengeMessage)
 	c, _ := ParseChallengeMessage(challengeData)
 	c, _ := ParseChallengeMessage(challengeData)
 
 
-  authenticateData, _ := base64.StdEncoding.DecodeString(authenticateMessage)
-  msg, err := ParseAuthenticateMessage(authenticateData, 1)
+	authenticateData, _ := base64.StdEncoding.DecodeString(authenticateMessage)
+	msg, err := ParseAuthenticateMessage(authenticateData, 1)
 	if err != nil {
 	if err != nil {
 		t.Errorf("Could not process authenticate message: %s", err)
 		t.Errorf("Could not process authenticate message: %s", err)
 	}
 	}
@@ -62,7 +62,7 @@ func TestNtlmV1ExtendedSessionSecurity(t *testing.T) {
 	context.SetServerChallenge(c.ServerChallenge)
 	context.SetServerChallenge(c.ServerChallenge)
 	err = context.ProcessAuthenticateMessage(msg)
 	err = context.ProcessAuthenticateMessage(msg)
 	if err == nil {
 	if err == nil {
-		t.Errorf("This message should have failed to authenticate, but it passed", err)
+		t.Errorf("This message should have failed to authenticate, but it passed")
 	}
 	}
 }
 }
 
 

+ 6 - 2
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2.go

@@ -88,7 +88,6 @@ func (n *V2Session) Sign(message []byte) ([]byte, error) {
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-//Mildly ghetto that we expose this
 func NtlmVCommonMac(message []byte, sequenceNumber int, sealingKey, signingKey []byte, NegotiateFlags uint32) []byte {
 func NtlmVCommonMac(message []byte, sequenceNumber int, sealingKey, signingKey []byte, NegotiateFlags uint32) []byte {
 	var handle *rc4P.Cipher
 	var handle *rc4P.Cipher
 	// TODO: Need to keep track of the sequence number for connection oriented NTLM
 	// TODO: Need to keep track of the sequence number for connection oriented NTLM
@@ -378,7 +377,12 @@ func (n *V2ClientSession) GenerateAuthenticateMessage() (am *AuthenticateMessage
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.UserName, _ = CreateStringPayload(n.user)
 	am.UserName, _ = CreateStringPayload(n.user)
-	am.Workstation, _ = CreateStringPayload("SQUAREMILL")
+
+	// [Psiphon]
+	// Set a blank workstation value, which is less distinguishable than the previous hard-coded value.
+	// See also: https://github.com/Azure/go-ntlmssp/commit/5e29b886690f00c76b876ae9ab8e31e7c3509203.
+
+	am.Workstation, _ = CreateStringPayload("")
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.NegotiateFlags = n.NegotiateFlags
 	am.NegotiateFlags = n.NegotiateFlags
 	am.Mic = make([]byte, 16)
 	am.Mic = make([]byte, 16)

+ 1 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2_test.go

@@ -172,7 +172,7 @@ func TestNTLMv2WithDomain(t *testing.T) {
 
 
 	err := server.ProcessAuthenticateMessage(a)
 	err := server.ProcessAuthenticateMessage(a)
 	if err != nil {
 	if err != nil {
-		t.Error("Could not process authenticate message: %s\n", err)
+		t.Errorf("Could not process authenticate message: %s\n", err)
 	}
 	}
 }
 }
 
 

+ 14 - 0
psiphon/upstreamproxy/go-ntlm/ntlm/payload.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"bytes"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
+	"errors"
 )
 )
 
 
 const (
 const (
@@ -80,6 +81,12 @@ func ReadBytePayload(startByte int, bytes []byte) (*PayloadStruct, error) {
 func ReadPayloadStruct(startByte int, bytes []byte, PayloadType int) (*PayloadStruct, error) {
 func ReadPayloadStruct(startByte int, bytes []byte, PayloadType int) (*PayloadStruct, error) {
 	p := new(PayloadStruct)
 	p := new(PayloadStruct)
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < startByte+8 {
+		return nil, errors.New("invalid payload")
+	}
+
 	p.Type = PayloadType
 	p.Type = PayloadType
 	p.Len = binary.LittleEndian.Uint16(bytes[startByte : startByte+2])
 	p.Len = binary.LittleEndian.Uint16(bytes[startByte : startByte+2])
 	p.MaxLen = binary.LittleEndian.Uint16(bytes[startByte+2 : startByte+4])
 	p.MaxLen = binary.LittleEndian.Uint16(bytes[startByte+2 : startByte+4])
@@ -87,6 +94,13 @@ func ReadPayloadStruct(startByte int, bytes []byte, PayloadType int) (*PayloadSt
 
 
 	if p.Len > 0 {
 	if p.Len > 0 {
 		endOffset := p.Offset + uint32(p.Len)
 		endOffset := p.Offset + uint32(p.Len)
+
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(bytes) < int(endOffset) {
+			return nil, errors.New("invalid payload")
+		}
+
 		p.Payload = bytes[p.Offset:endOffset]
 		p.Payload = bytes[p.Offset:endOffset]
 	}
 	}
 
 

+ 7 - 0
psiphon/upstreamproxy/go-ntlm/ntlm/version.go

@@ -5,6 +5,7 @@ package ntlm
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/binary"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"fmt"
 )
 )
 
 
@@ -19,6 +20,12 @@ type VersionStruct struct {
 func ReadVersionStruct(structSource []byte) (*VersionStruct, error) {
 func ReadVersionStruct(structSource []byte) (*VersionStruct, error) {
 	versionStruct := new(VersionStruct)
 	versionStruct := new(VersionStruct)
 
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(structSource) < 8 {
+		return nil, errors.New("invalid version struct")
+	}
+
 	versionStruct.ProductMajorVersion = uint8(structSource[0])
 	versionStruct.ProductMajorVersion = uint8(structSource[0])
 	versionStruct.ProductMinorVersion = uint8(structSource[1])
 	versionStruct.ProductMinorVersion = uint8(structSource[1])
 	versionStruct.ProductBuild = binary.LittleEndian.Uint16(structSource[2:4])
 	versionStruct.ProductBuild = binary.LittleEndian.Uint16(structSource[2:4])

+ 7 - 1
vendor/github.com/Psiphon-Labs/tls-tris/key_agreement.go

@@ -72,7 +72,13 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}
 
 
-	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), pk.(*rsa.PublicKey), preMasterSecret)
+	// [Psiphon]
+	// Backport fix: https://github.com/golang/go/commit/58bc454a11d4b3dbc03f44dfcabb9068a9c076f4
+	rsaKey, ok := pk.(*rsa.PublicKey)
+	if !ok {
+		return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
+	}
+	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
 	if err != nil {
 	if err != nil {
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}

+ 1 - 1
vendor/github.com/miekg/dns/Makefile.release

@@ -1,7 +1,7 @@
 # Makefile for releasing.
 # Makefile for releasing.
 #
 #
 # The release is controlled from version.go. The version found there is
 # The release is controlled from version.go. The version found there is
-# used to tag the git repo, we're not building any artifects so there is nothing
+# used to tag the git repo, we're not building any artifacts so there is nothing
 # to upload to github.
 # to upload to github.
 #
 #
 # * Up the version in version.go
 # * Up the version in version.go

+ 14 - 6
vendor/github.com/miekg/dns/README.md

@@ -26,7 +26,6 @@ avoiding breaking changes wherever reasonable. We support the last two versions
 A not-so-up-to-date-list-that-may-be-actually-current:
 A not-so-up-to-date-list-that-may-be-actually-current:
 
 
 * https://github.com/coredns/coredns
 * https://github.com/coredns/coredns
-* https://cloudflare.com
 * https://github.com/abh/geodns
 * https://github.com/abh/geodns
 * https://github.com/baidu/bfe
 * https://github.com/baidu/bfe
 * http://www.statdns.com/
 * http://www.statdns.com/
@@ -42,11 +41,9 @@ A not-so-up-to-date-list-that-may-be-actually-current:
 * https://github.com/StalkR/dns-reverse-proxy
 * https://github.com/StalkR/dns-reverse-proxy
 * https://github.com/tianon/rawdns
 * https://github.com/tianon/rawdns
 * https://mesosphere.github.io/mesos-dns/
 * https://mesosphere.github.io/mesos-dns/
-* https://pulse.turbobytes.com/
 * https://github.com/fcambus/statzone
 * https://github.com/fcambus/statzone
 * https://github.com/benschw/dns-clb-go
 * https://github.com/benschw/dns-clb-go
 * https://github.com/corny/dnscheck for <http://public-dns.info/>
 * https://github.com/corny/dnscheck for <http://public-dns.info/>
-* https://namesmith.io
 * https://github.com/miekg/unbound
 * https://github.com/miekg/unbound
 * https://github.com/miekg/exdns
 * https://github.com/miekg/exdns
 * https://dnslookup.org
 * https://dnslookup.org
@@ -55,22 +52,30 @@ A not-so-up-to-date-list-that-may-be-actually-current:
 * https://github.com/mehrdadrad/mylg
 * https://github.com/mehrdadrad/mylg
 * https://github.com/bamarni/dockness
 * https://github.com/bamarni/dockness
 * https://github.com/fffaraz/microdns
 * https://github.com/fffaraz/microdns
-* http://kelda.io
 * https://github.com/ipdcode/hades <https://jd.com>
 * https://github.com/ipdcode/hades <https://jd.com>
 * https://github.com/StackExchange/dnscontrol/
 * https://github.com/StackExchange/dnscontrol/
 * https://www.dnsperf.com/
 * https://www.dnsperf.com/
 * https://dnssectest.net/
 * https://dnssectest.net/
-* https://dns.apebits.com
 * https://github.com/oif/apex
 * https://github.com/oif/apex
 * https://github.com/jedisct1/dnscrypt-proxy
 * https://github.com/jedisct1/dnscrypt-proxy
 * https://github.com/jedisct1/rpdns
 * https://github.com/jedisct1/rpdns
 * https://github.com/xor-gate/sshfp
 * https://github.com/xor-gate/sshfp
 * https://github.com/rs/dnstrace
 * https://github.com/rs/dnstrace
 * https://blitiri.com.ar/p/dnss ([github mirror](https://github.com/albertito/dnss))
 * https://blitiri.com.ar/p/dnss ([github mirror](https://github.com/albertito/dnss))
-* https://github.com/semihalev/sdns
 * https://render.com
 * https://render.com
 * https://github.com/peterzen/goresolver
 * https://github.com/peterzen/goresolver
 * https://github.com/folbricht/routedns
 * https://github.com/folbricht/routedns
+* https://domainr.com/
+* https://zonedb.org/
+* https://router7.org/
+* https://github.com/fortio/dnsping
+* https://github.com/Luzilla/dnsbl_exporter
+* https://github.com/bodgit/tsig
+* https://github.com/v2fly/v2ray-core (test only)
+* https://kuma.io/
+* https://www.misaka.io/services/dns
+* https://ping.sx/dig
+
 
 
 Send pull request if you want to be listed here.
 Send pull request if you want to be listed here.
 
 
@@ -167,6 +172,9 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
 * 7873 - Domain Name System (DNS) Cookies
 * 7873 - Domain Name System (DNS) Cookies
 * 8080 - EdDSA for DNSSEC
 * 8080 - EdDSA for DNSSEC
 * 8499 - DNS Terminology
 * 8499 - DNS Terminology
+* 8659 - DNS Certification Authority Authorization (CAA) Resource Record
+* 8914 - Extended DNS Errors
+* 8976 - Message Digest for DNS Zones (ZONEMD RR)
 
 
 ## Loosely Based Upon
 ## Loosely Based Upon
 
 

+ 1 - 0
vendor/github.com/miekg/dns/acceptfunc.go

@@ -25,6 +25,7 @@ var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc
 // MsgAcceptAction represents the action to be taken.
 // MsgAcceptAction represents the action to be taken.
 type MsgAcceptAction int
 type MsgAcceptAction int
 
 
+// Allowed returned values from a MsgAcceptFunc.
 const (
 const (
 	MsgAccept               MsgAcceptAction = iota // Accept the message
 	MsgAccept               MsgAcceptAction = iota // Accept the message
 	MsgReject                                      // Reject the message with a RcodeFormatError
 	MsgReject                                      // Reject the message with a RcodeFormatError

+ 110 - 45
vendor/github.com/miekg/dns/client.go

@@ -23,6 +23,7 @@ type Conn struct {
 	net.Conn                         // a net.Conn holding the connection
 	net.Conn                         // a net.Conn holding the connection
 	UDPSize        uint16            // minimum receive buffer for UDP messages
 	UDPSize        uint16            // minimum receive buffer for UDP messages
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
+	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
 	tsigRequestMAC string
 	tsigRequestMAC string
 }
 }
 
 
@@ -34,12 +35,13 @@ type Client struct {
 	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more
 	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more
 	// Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout,
 	// Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout,
 	// WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and
 	// WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and
-	// Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext)
+	// Client.Dialer) or context.Context.Deadline (see ExchangeContext)
 	Timeout        time.Duration
 	Timeout        time.Duration
 	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
 	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
 	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
+	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
 	SingleInflight bool              // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
 	SingleInflight bool              // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
 	group          singleflight
 	group          singleflight
 }
 }
@@ -80,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration {
 
 
 // Dial connects to the address on the named network.
 // Dial connects to the address on the named network.
 func (c *Client) Dial(address string) (conn *Conn, err error) {
 func (c *Client) Dial(address string) (conn *Conn, err error) {
+	return c.DialContext(context.Background(), address)
+}
+
+// DialContext connects to the address on the named network, with a context.Context.
+// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released.
+func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
 	// create a new dialer with the appropriate timeout
 	// create a new dialer with the appropriate timeout
 	var d net.Dialer
 	var d net.Dialer
 	if c.Dialer == nil {
 	if c.Dialer == nil {
@@ -99,14 +107,22 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
 	if useTLS {
 	if useTLS {
 		network = strings.TrimSuffix(network, "-tls")
 		network = strings.TrimSuffix(network, "-tls")
 
 
+		// TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases.
+		/*
+			tlsDialer := tls.Dialer{
+				NetDialer: &d,
+				Config:    c.TLSConfig,
+			}
+			conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
+		*/
 		conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
 		conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
 	} else {
 	} else {
-		conn.Conn, err = d.Dial(network, address)
+		conn.Conn, err = d.DialContext(ctx, network, address)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-
+	conn.UDPSize = c.UDPSize
 	return conn, nil
 	return conn, nil
 }
 }
 
 
@@ -125,14 +141,46 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
 // To specify a local address or a timeout, the caller has to set the `Client.Dialer`
 // To specify a local address or a timeout, the caller has to set the `Client.Dialer`
 // attribute appropriately
 // attribute appropriately
 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
+	co, err := c.Dial(address)
+
+	if err != nil {
+		return nil, 0, err
+	}
+	defer co.Close()
+	return c.ExchangeWithConn(m, co)
+}
+
+// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
+// that will be used instead of creating a new one.
+// Usage pattern with a *dns.Client:
+//
+//	c := new(dns.Client)
+//	// connection management logic goes here
+//
+//	conn := c.Dial(address)
+//	in, rtt, err := c.ExchangeWithConn(message, conn)
+//
+// This allows users of the library to implement their own connection management,
+// as opposed to Exchange, which will always use new connections and incur the added overhead
+// that entails when using "tcp" and especially "tcp-tls" clients.
+//
+// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
+// prevent one cancelation from canceling all outstanding requests.
+func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
+	return c.exchangeWithConnContext(context.Background(), m, conn)
+}
+
+func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
 	if !c.SingleInflight {
 	if !c.SingleInflight {
-		return c.exchange(m, address)
+		return c.exchangeContext(ctx, m, conn)
 	}
 	}
 
 
 	q := m.Question[0]
 	q := m.Question[0]
 	key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
 	key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
 	r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
 	r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
-		return c.exchange(m, address)
+		// When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries.
+		// Hence we ignore the context and use Background().
+		return c.exchangeContext(context.Background(), m, conn)
 	})
 	})
 	if r != nil && shared {
 	if r != nil && shared {
 		r = r.Copy()
 		r = r.Copy()
@@ -141,16 +189,7 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
 	return r, rtt, err
 	return r, rtt, err
 }
 }
 
 
-func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
-	var co *Conn
-
-	co, err = c.Dial(a)
-
-	if err != nil {
-		return nil, 0, err
-	}
-	defer co.Close()
-
+func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
 	opt := m.IsEdns0()
 	opt := m.IsEdns0()
 	// If EDNS0 is used use that for size.
 	// If EDNS0 is used use that for size.
 	if opt != nil && opt.UDPSize() >= MinMsgSize {
 	if opt != nil && opt.UDPSize() >= MinMsgSize {
@@ -161,18 +200,41 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
 		co.UDPSize = c.UDPSize
 		co.UDPSize = c.UDPSize
 	}
 	}
 
 
-	co.TsigSecret = c.TsigSecret
-	t := time.Now()
 	// write with the appropriate write timeout
 	// write with the appropriate write timeout
-	co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
+	t := time.Now()
+	writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
+	readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
+	if deadline, ok := ctx.Deadline(); ok {
+		if deadline.Before(writeDeadline) {
+			writeDeadline = deadline
+		}
+		if deadline.Before(readDeadline) {
+			readDeadline = deadline
+		}
+	}
+	co.SetWriteDeadline(writeDeadline)
+	co.SetReadDeadline(readDeadline)
+
+	co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
+
 	if err = co.WriteMsg(m); err != nil {
 	if err = co.WriteMsg(m); err != nil {
 		return nil, 0, err
 		return nil, 0, err
 	}
 	}
 
 
-	co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout())))
-	r, err = co.ReadMsg()
-	if err == nil && r.Id != m.Id {
-		err = ErrId
+	if _, ok := co.Conn.(net.PacketConn); ok {
+		for {
+			r, err = co.ReadMsg()
+			// Ignore replies with mismatched IDs because they might be
+			// responses to earlier queries that timed out.
+			if err != nil || r.Id == m.Id {
+				break
+			}
+		}
+	} else {
+		r, err = co.ReadMsg()
+		if err == nil && r.Id != m.Id {
+			err = ErrId
+		}
 	}
 	}
 	rtt = time.Since(t)
 	rtt = time.Since(t)
 	return r, rtt, err
 	return r, rtt, err
@@ -197,11 +259,15 @@ func (co *Conn) ReadMsg() (*Msg, error) {
 		return m, err
 		return m, err
 	}
 	}
 	if t := m.IsTsig(); t != nil {
 	if t := m.IsTsig(); t != nil {
-		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
-			return m, ErrSecret
+		if co.TsigProvider != nil {
+			err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
+		} else {
+			if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
+				return m, ErrSecret
+			}
+			// Need to work on the original message p, as that was used to calculate the tsig.
+			err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		}
 		}
-		// Need to work on the original message p, as that was used to calculate the tsig.
-		err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 	}
 	}
 	return m, err
 	return m, err
 }
 }
@@ -279,10 +345,14 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
 	var out []byte
 	var out []byte
 	if t := m.IsTsig(); t != nil {
 	if t := m.IsTsig(); t != nil {
 		mac := ""
 		mac := ""
-		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
-			return ErrSecret
+		if co.TsigProvider != nil {
+			out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
+		} else {
+			if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
+				return ErrSecret
+			}
+			out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		}
 		}
-		out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		// Set for the next read, although only used in zone transfers
 		// Set for the next read, although only used in zone transfers
 		co.tsigRequestMAC = mac
 		co.tsigRequestMAC = mac
 	} else {
 	} else {
@@ -305,11 +375,10 @@ func (co *Conn) Write(p []byte) (int, error) {
 		return co.Conn.Write(p)
 		return co.Conn.Write(p)
 	}
 	}
 
 
-	l := make([]byte, 2)
-	binary.BigEndian.PutUint16(l, uint16(len(p)))
-
-	n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
-	return int(n), err
+	msg := make([]byte, 2+len(p))
+	binary.BigEndian.PutUint16(msg, uint16(len(p)))
+	copy(msg[2:], p)
+	return co.Conn.Write(msg)
 }
 }
 
 
 // Return the appropriate timeout for a specific request
 // Return the appropriate timeout for a specific request
@@ -345,7 +414,7 @@ func Dial(network, address string) (conn *Conn, err error) {
 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
 	client := Client{Net: "udp"}
 	client := Client{Net: "udp"}
 	r, _, err = client.ExchangeContext(ctx, m, a)
 	r, _, err = client.ExchangeContext(ctx, m, a)
-	// ignorint rtt to leave the original ExchangeContext API unchanged, but
+	// ignoring rtt to leave the original ExchangeContext API unchanged, but
 	// this function will go away
 	// this function will go away
 	return r, err
 	return r, err
 }
 }
@@ -401,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
 // context, if present. If there is both a context deadline and a configured
 // context, if present. If there is both a context deadline and a configured
 // timeout on the client, the earliest of the two takes effect.
 // timeout on the client, the earliest of the two takes effect.
 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
-	var timeout time.Duration
-	if deadline, ok := ctx.Deadline(); !ok {
-		timeout = 0
-	} else {
-		timeout = time.Until(deadline)
+	conn, err := c.DialContext(ctx, a)
+	if err != nil {
+		return nil, 0, err
 	}
 	}
-	// not passing the context to the underlying calls, as the API does not support
-	// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
-	// TODO(tmthrgd,miekg): this is a race condition.
-	c.Dialer = &net.Dialer{Timeout: timeout}
-	return c.Exchange(m, a)
+	defer conn.Close()
+
+	return c.exchangeWithConnContext(ctx, m, conn)
 }
 }

+ 1 - 4
vendor/github.com/miekg/dns/defaults.go

@@ -349,10 +349,7 @@ func ReverseAddr(addr string) (arpa string, err error) {
 	// Add it, in reverse, to the buffer
 	// Add it, in reverse, to the buffer
 	for i := len(ip) - 1; i >= 0; i-- {
 	for i := len(ip) - 1; i >= 0; i-- {
 		v := ip[i]
 		v := ip[i]
-		buf = append(buf, hexDigit[v&0xF])
-		buf = append(buf, '.')
-		buf = append(buf, hexDigit[v>>4])
-		buf = append(buf, '.')
+		buf = append(buf, hexDigit[v&0xF], '.', hexDigit[v>>4], '.')
 	}
 	}
 	// Append "ip6.arpa." and return (buf already has the final .)
 	// Append "ip6.arpa." and return (buf already has the final .)
 	buf = append(buf, "ip6.arpa."...)
 	buf = append(buf, "ip6.arpa."...)

+ 27 - 3
vendor/github.com/miekg/dns/dns.go

@@ -1,6 +1,9 @@
 package dns
 package dns
 
 
-import "strconv"
+import (
+	"encoding/hex"
+	"strconv"
+)
 
 
 const (
 const (
 	year68     = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
 	year68     = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
@@ -111,7 +114,7 @@ func (h *RR_Header) parse(c *zlexer, origin string) *ParseError {
 
 
 // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
 // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
 func (rr *RFC3597) ToRFC3597(r RR) error {
 func (rr *RFC3597) ToRFC3597(r RR) error {
-	buf := make([]byte, Len(r)*2)
+	buf := make([]byte, Len(r))
 	headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
 	headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -126,9 +129,30 @@ func (rr *RFC3597) ToRFC3597(r RR) error {
 	}
 	}
 
 
 	_, err = rr.unpack(buf, headerEnd)
 	_, err = rr.unpack(buf, headerEnd)
+	return err
+}
+
+// fromRFC3597 converts an unknown RR representation from RFC 3597 to the known RR type.
+func (rr *RFC3597) fromRFC3597(r RR) error {
+	hdr := r.Header()
+	*hdr = rr.Hdr
+
+	// Can't overflow uint16 as the length of Rdata is validated in (*RFC3597).parse.
+	// We can only get here when rr was constructed with that method.
+	hdr.Rdlength = uint16(hex.DecodedLen(len(rr.Rdata)))
+
+	if noRdata(*hdr) {
+		// Dynamic update.
+		return nil
+	}
+
+	// rr.pack requires an extra allocation and a copy so we just decode Rdata
+	// manually, it's simpler anyway.
+	msg, err := hex.DecodeString(rr.Rdata)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	return nil
+	_, err = r.unpack(msg, 0)
+	return err
 }
 }

+ 18 - 47
vendor/github.com/miekg/dns/dnssec.go

@@ -3,15 +3,14 @@ package dns
 import (
 import (
 	"bytes"
 	"bytes"
 	"crypto"
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/elliptic"
-	_ "crypto/md5"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/rsa"
-	_ "crypto/sha1"
-	_ "crypto/sha256"
-	_ "crypto/sha512"
+	_ "crypto/sha1"   // need its init function
+	_ "crypto/sha256" // need its init function
+	_ "crypto/sha512" // need its init function
 	"encoding/asn1"
 	"encoding/asn1"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
@@ -19,8 +18,6 @@ import (
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 	"time"
 	"time"
-
-	"golang.org/x/crypto/ed25519"
 )
 )
 
 
 // DNSSEC encryption algorithm codes.
 // DNSSEC encryption algorithm codes.
@@ -318,6 +315,7 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
 		}
 		}
 
 
 		rr.Signature = toBase64(signature)
 		rr.Signature = toBase64(signature)
+		return nil
 	case RSAMD5, DSA, DSANSEC3SHA1:
 	case RSAMD5, DSA, DSANSEC3SHA1:
 		// See RFC 6944.
 		// See RFC 6944.
 		return ErrAlg
 		return ErrAlg
@@ -332,9 +330,8 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
 		}
 		}
 
 
 		rr.Signature = toBase64(signature)
 		rr.Signature = toBase64(signature)
+		return nil
 	}
 	}
-
-	return nil
 }
 }
 
 
 func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, error) {
 func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, error) {
@@ -346,7 +343,6 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte,
 	switch alg {
 	switch alg {
 	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		return signature, nil
 		return signature, nil
-
 	case ECDSAP256SHA256, ECDSAP384SHA384:
 	case ECDSAP256SHA256, ECDSAP384SHA384:
 		ecdsaSignature := &struct {
 		ecdsaSignature := &struct {
 			R, S *big.Int
 			R, S *big.Int
@@ -366,25 +362,18 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte,
 		signature := intToBytes(ecdsaSignature.R, intlen)
 		signature := intToBytes(ecdsaSignature.R, intlen)
 		signature = append(signature, intToBytes(ecdsaSignature.S, intlen)...)
 		signature = append(signature, intToBytes(ecdsaSignature.S, intlen)...)
 		return signature, nil
 		return signature, nil
-
-	// There is no defined interface for what a DSA backed crypto.Signer returns
-	case DSA, DSANSEC3SHA1:
-		// 	t := divRoundUp(divRoundUp(p.PublicKey.Y.BitLen(), 8)-64, 8)
-		// 	signature := []byte{byte(t)}
-		// 	signature = append(signature, intToBytes(r1, 20)...)
-		// 	signature = append(signature, intToBytes(s1, 20)...)
-		// 	rr.Signature = signature
-
 	case ED25519:
 	case ED25519:
 		return signature, nil
 		return signature, nil
+	default:
+		return nil, ErrAlg
 	}
 	}
-
-	return nil, ErrAlg
 }
 }
 
 
 // Verify validates an RRSet with the signature and key. This is only the
 // Verify validates an RRSet with the signature and key. This is only the
 // cryptographic test, the signature validity period must be checked separately.
 // cryptographic test, the signature validity period must be checked separately.
 // This function copies the rdata of some RRs (to lowercase domain names) for the validation to work.
 // This function copies the rdata of some RRs (to lowercase domain names) for the validation to work.
+// It also checks that the Zone Key bit (RFC 4034 2.1.1) is set on the DNSKEY
+// and that the Protocol field is set to 3 (RFC 4034 2.1.2).
 func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	// First the easy checks
 	// First the easy checks
 	if !IsRRset(rrset) {
 	if !IsRRset(rrset) {
@@ -405,6 +394,12 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	if k.Protocol != 3 {
 	if k.Protocol != 3 {
 		return ErrKey
 		return ErrKey
 	}
 	}
+	// RFC 4034 2.1.1 If bit 7 has value 0, then the DNSKEY record holds some
+	// other type of DNS public key and MUST NOT be used to verify RRSIGs that
+	// cover RRsets.
+	if k.Flags&ZONE == 0 {
+		return ErrKey
+	}
 
 
 	// IsRRset checked that we have at least one RR and that the RRs in
 	// IsRRset checked that we have at least one RR and that the RRs in
 	// the set have consistent type, class, and name. Also check that type and
 	// the set have consistent type, class, and name. Also check that type and
@@ -448,7 +443,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	}
 	}
 
 
 	switch rr.Algorithm {
 	switch rr.Algorithm {
-	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, RSAMD5:
+	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		// TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere??
 		// TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere??
 		pubkey := k.publicKeyRSA() // Get the key
 		pubkey := k.publicKeyRSA() // Get the key
 		if pubkey == nil {
 		if pubkey == nil {
@@ -512,7 +507,7 @@ func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
 	return ti <= utc && utc <= te
 	return ti <= utc && utc <= te
 }
 }
 
 
-// Return the signatures base64 encodedig sigdata as a byte slice.
+// Return the signatures base64 encoding sigdata as a byte slice.
 func (rr *RRSIG) sigBuf() []byte {
 func (rr *RRSIG) sigBuf() []byte {
 	sigbuf, err := fromBase64([]byte(rr.Signature))
 	sigbuf, err := fromBase64([]byte(rr.Signature))
 	if err != nil {
 	if err != nil {
@@ -600,30 +595,6 @@ func (k *DNSKEY) publicKeyECDSA() *ecdsa.PublicKey {
 	return pubkey
 	return pubkey
 }
 }
 
 
-func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
-	keybuf, err := fromBase64([]byte(k.PublicKey))
-	if err != nil {
-		return nil
-	}
-	if len(keybuf) < 22 {
-		return nil
-	}
-	t, keybuf := int(keybuf[0]), keybuf[1:]
-	size := 64 + t*8
-	q, keybuf := keybuf[:20], keybuf[20:]
-	if len(keybuf) != 3*size {
-		return nil
-	}
-	p, keybuf := keybuf[:size], keybuf[size:]
-	g, y := keybuf[:size], keybuf[size:]
-	pubkey := new(dsa.PublicKey)
-	pubkey.Parameters.Q = new(big.Int).SetBytes(q)
-	pubkey.Parameters.P = new(big.Int).SetBytes(p)
-	pubkey.Parameters.G = new(big.Int).SetBytes(g)
-	pubkey.Y = new(big.Int).SetBytes(y)
-	return pubkey
-}
-
 func (k *DNSKEY) publicKeyED25519() ed25519.PublicKey {
 func (k *DNSKEY) publicKeyED25519() ed25519.PublicKey {
 	keybuf, err := fromBase64([]byte(k.PublicKey))
 	keybuf, err := fromBase64([]byte(k.PublicKey))
 	if err != nil {
 	if err != nil {

+ 3 - 4
vendor/github.com/miekg/dns/dnssec_keygen.go

@@ -3,12 +3,11 @@ package dns
 import (
 import (
 	"crypto"
 	"crypto"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/rsa"
 	"math/big"
 	"math/big"
-
-	"golang.org/x/crypto/ed25519"
 )
 )
 
 
 // Generate generates a DNSKEY of the given bit size.
 // Generate generates a DNSKEY of the given bit size.
@@ -19,8 +18,6 @@ import (
 // bits should be set to the size of the algorithm.
 // bits should be set to the size of the algorithm.
 func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
 func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
 	switch k.Algorithm {
 	switch k.Algorithm {
-	case RSAMD5, DSA, DSANSEC3SHA1:
-		return nil, ErrAlg
 	case RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
 	case RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
 		if bits < 512 || bits > 4096 {
 		if bits < 512 || bits > 4096 {
 			return nil, ErrKeySize
 			return nil, ErrKeySize
@@ -41,6 +38,8 @@ func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
 		if bits != 256 {
 		if bits != 256 {
 			return nil, ErrKeySize
 			return nil, ErrKeySize
 		}
 		}
+	default:
+		return nil, ErrAlg
 	}
 	}
 
 
 	switch k.Algorithm {
 	switch k.Algorithm {

+ 4 - 17
vendor/github.com/miekg/dns/dnssec_keyscan.go

@@ -4,13 +4,12 @@ import (
 	"bufio"
 	"bufio"
 	"crypto"
 	"crypto"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/rsa"
 	"crypto/rsa"
 	"io"
 	"io"
 	"math/big"
 	"math/big"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
-
-	"golang.org/x/crypto/ed25519"
 )
 )
 
 
 // NewPrivateKey returns a PrivateKey by parsing the string s.
 // NewPrivateKey returns a PrivateKey by parsing the string s.
@@ -43,15 +42,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 		return nil, ErrPrivKey
 		return nil, ErrPrivKey
 	}
 	}
 	switch uint8(algo) {
 	switch uint8(algo) {
-	case RSAMD5, DSA, DSANSEC3SHA1:
-		return nil, ErrAlg
-	case RSASHA1:
-		fallthrough
-	case RSASHA1NSEC3SHA1:
-		fallthrough
-	case RSASHA256:
-		fallthrough
-	case RSASHA512:
+	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		priv, err := readPrivateKeyRSA(m)
 		priv, err := readPrivateKeyRSA(m)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -62,11 +53,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 		}
 		}
 		priv.PublicKey = *pub
 		priv.PublicKey = *pub
 		return priv, nil
 		return priv, nil
-	case ECCGOST:
-		return nil, ErrPrivKey
-	case ECDSAP256SHA256:
-		fallthrough
-	case ECDSAP384SHA384:
+	case ECDSAP256SHA256, ECDSAP384SHA384:
 		priv, err := readPrivateKeyECDSA(m)
 		priv, err := readPrivateKeyECDSA(m)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -80,7 +67,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 	case ED25519:
 	case ED25519:
 		return readPrivateKeyED25519(m)
 		return readPrivateKeyED25519(m)
 	default:
 	default:
-		return nil, ErrPrivKey
+		return nil, ErrAlg
 	}
 	}
 }
 }
 
 

+ 3 - 20
vendor/github.com/miekg/dns/dnssec_privkey.go

@@ -2,13 +2,11 @@ package dns
 
 
 import (
 import (
 	"crypto"
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/rsa"
 	"crypto/rsa"
 	"math/big"
 	"math/big"
 	"strconv"
 	"strconv"
-
-	"golang.org/x/crypto/ed25519"
 )
 )
 
 
 const format = "Private-key-format: v1.3\n"
 const format = "Private-key-format: v1.3\n"
@@ -17,8 +15,8 @@ var bigIntOne = big.NewInt(1)
 
 
 // PrivateKeyString converts a PrivateKey to a string. This string has the same
 // PrivateKeyString converts a PrivateKey to a string. This string has the same
 // format as the private-key-file of BIND9 (Private-key-format: v1.3).
 // format as the private-key-file of BIND9 (Private-key-format: v1.3).
-// It needs some info from the key (the algorithm), so its a method of the DNSKEY
-// It supports rsa.PrivateKey, ecdsa.PrivateKey and dsa.PrivateKey
+// It needs some info from the key (the algorithm), so its a method of the DNSKEY.
+// It supports *rsa.PrivateKey, *ecdsa.PrivateKey and ed25519.PrivateKey.
 func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
 func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
 	algorithm := strconv.Itoa(int(r.Algorithm))
 	algorithm := strconv.Itoa(int(r.Algorithm))
 	algorithm += " (" + AlgorithmToString[r.Algorithm] + ")"
 	algorithm += " (" + AlgorithmToString[r.Algorithm] + ")"
@@ -67,21 +65,6 @@ func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
 			"Algorithm: " + algorithm + "\n" +
 			"Algorithm: " + algorithm + "\n" +
 			"PrivateKey: " + private + "\n"
 			"PrivateKey: " + private + "\n"
 
 
-	case *dsa.PrivateKey:
-		T := divRoundUp(divRoundUp(p.PublicKey.Parameters.G.BitLen(), 8)-64, 8)
-		prime := toBase64(intToBytes(p.PublicKey.Parameters.P, 64+T*8))
-		subprime := toBase64(intToBytes(p.PublicKey.Parameters.Q, 20))
-		base := toBase64(intToBytes(p.PublicKey.Parameters.G, 64+T*8))
-		priv := toBase64(intToBytes(p.X, 20))
-		pub := toBase64(intToBytes(p.PublicKey.Y, 64+T*8))
-		return format +
-			"Algorithm: " + algorithm + "\n" +
-			"Prime(p): " + prime + "\n" +
-			"Subprime(q): " + subprime + "\n" +
-			"Base(g): " + base + "\n" +
-			"Private_value(x): " + priv + "\n" +
-			"Public_value(y): " + pub + "\n"
-
 	case ed25519.PrivateKey:
 	case ed25519.PrivateKey:
 		private := toBase64(p.Seed())
 		private := toBase64(p.Seed())
 		return format +
 		return format +

+ 29 - 5
vendor/github.com/miekg/dns/doc.go

@@ -159,7 +159,7 @@ shows the options you have and what functions to call.
 TRANSACTION SIGNATURE
 TRANSACTION SIGNATURE
 
 
 An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
 An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
-The supported algorithms include: HmacMD5, HmacSHA1, HmacSHA256 and HmacSHA512.
+The supported algorithms include: HmacSHA1, HmacSHA256 and HmacSHA512.
 
 
 Basic use pattern when querying with a TSIG name "axfr." (note that these key names
 Basic use pattern when querying with a TSIG name "axfr." (note that these key names
 must be fully qualified - as they are domain names) and the base64 secret
 must be fully qualified - as they are domain names) and the base64 secret
@@ -174,7 +174,7 @@ changes to the RRset after calling SetTsig() the signature will be incorrect.
 	c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	m := new(dns.Msg)
 	m := new(dns.Msg)
 	m.SetQuestion("miek.nl.", dns.TypeMX)
 	m.SetQuestion("miek.nl.", dns.TypeMX)
-	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+	m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 	...
 	...
 	// When sending the TSIG RR is calculated and filled in before sending
 	// When sending the TSIG RR is calculated and filled in before sending
 
 
@@ -187,13 +187,37 @@ request an AXFR for miek.nl. with TSIG key named "axfr." and secret
 	m := new(dns.Msg)
 	m := new(dns.Msg)
 	t.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	t.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	m.SetAxfr("miek.nl.")
 	m.SetAxfr("miek.nl.")
-	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+	m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 	c, err := t.In(m, "176.58.119.54:53")
 	c, err := t.In(m, "176.58.119.54:53")
 	for r := range c { ... }
 	for r := range c { ... }
 
 
 You can now read the records from the transfer as they come in. Each envelope
 You can now read the records from the transfer as they come in. Each envelope
 is checked with TSIG. If something is not correct an error is returned.
 is checked with TSIG. If something is not correct an error is returned.
 
 
+A custom TSIG implementation can be used. This requires additional code to
+perform any session establishment and signature generation/verification. The
+client must be configured with an implementation of the TsigProvider interface:
+
+	type Provider struct{}
+
+	func (*Provider) Generate(msg []byte, tsig *dns.TSIG) ([]byte, error) {
+		// Use tsig.Hdr.Name and tsig.Algorithm in your code to
+		// generate the MAC using msg as the payload.
+	}
+
+	func (*Provider) Verify(msg []byte, tsig *dns.TSIG) error {
+		// Use tsig.Hdr.Name and tsig.Algorithm in your code to verify
+		// that msg matches the value in tsig.MAC.
+	}
+
+	c := new(dns.Client)
+	c.TsigProvider = new(Provider)
+	m := new(dns.Msg)
+	m.SetQuestion("miek.nl.", dns.TypeMX)
+	m.SetTsig(keyname, dns.HmacSHA256, 300, time.Now().Unix())
+	...
+	// TSIG RR is calculated by calling your Generate method
+
 Basic use pattern validating and replying to a message that has TSIG set.
 Basic use pattern validating and replying to a message that has TSIG set.
 
 
 	server := &dns.Server{Addr: ":53", Net: "udp"}
 	server := &dns.Server{Addr: ":53", Net: "udp"}
@@ -207,7 +231,7 @@ Basic use pattern validating and replying to a message that has TSIG set.
 		if r.IsTsig() != nil {
 		if r.IsTsig() != nil {
 			if w.TsigStatus() == nil {
 			if w.TsigStatus() == nil {
 				// *Msg r has an TSIG record and it was validated
 				// *Msg r has an TSIG record and it was validated
-				m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+				m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 			} else {
 			} else {
 				// *Msg r has an TSIG records and it was not validated
 				// *Msg r has an TSIG records and it was not validated
 			}
 			}
@@ -260,7 +284,7 @@ From RFC 2931:
     on requests and responses, and protection of the overall integrity of a response.
     on requests and responses, and protection of the overall integrity of a response.
 
 
 It works like TSIG, except that SIG(0) uses public key cryptography, instead of
 It works like TSIG, except that SIG(0) uses public key cryptography, instead of
-the shared secret approach in TSIG. Supported algorithms: DSA, ECDSAP256SHA256,
+the shared secret approach in TSIG. Supported algorithms: ECDSAP256SHA256,
 ECDSAP384SHA384, RSASHA1, RSASHA256 and RSASHA512.
 ECDSAP384SHA384, RSASHA1, RSASHA256 and RSASHA512.
 
 
 Signing subsequent messages in multi-message sessions is not implemented.
 Signing subsequent messages in multi-message sessions is not implemented.

+ 12 - 4
vendor/github.com/miekg/dns/duplicate_generate.go

@@ -30,6 +30,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 	if !ok {
 		return nil, false
 		return nil, false
 	}
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 		return st, false
 	}
 	}
@@ -83,10 +86,7 @@ func main() {
 	for _, name := range namedTypes {
 	for _, name := range namedTypes {
 
 
 		o := scope.Lookup(name)
 		o := scope.Lookup(name)
-		st, isEmbedded := getTypeStruct(o.Type(), scope)
-		if isEmbedded {
-			continue
-		}
+		st, _ := getTypeStruct(o.Type(), scope)
 		fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name)
 		fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name)
 		fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name)
 		fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name)
 		fmt.Fprint(b, "if !ok { return false }\n")
 		fmt.Fprint(b, "if !ok { return false }\n")
@@ -121,6 +121,14 @@ func main() {
 					continue
 					continue
 				}
 				}
 
 
+				if st.Tag(i) == `dns:"pairs"` {
+					o2(`if !areSVCBPairArraysEqual(r1.%s, r2.%s) {
+						return false
+					}`)
+
+					continue
+				}
+
 				o3(`for i := 0; i < len(r1.%s); i++ {
 				o3(`for i := 0; i < len(r1.%s); i++ {
 					if r1.%s[i] != r2.%s[i] {
 					if r1.%s[i] != r2.%s[i] {
 						return false
 						return false

+ 151 - 5
vendor/github.com/miekg/dns/edns.go

@@ -22,11 +22,47 @@ const (
 	EDNS0COOKIE       = 0xa     // EDNS0 Cookie
 	EDNS0COOKIE       = 0xa     // EDNS0 Cookie
 	EDNS0TCPKEEPALIVE = 0xb     // EDNS0 tcp keep alive (See RFC 7828)
 	EDNS0TCPKEEPALIVE = 0xb     // EDNS0 tcp keep alive (See RFC 7828)
 	EDNS0PADDING      = 0xc     // EDNS0 padding (See RFC 7830)
 	EDNS0PADDING      = 0xc     // EDNS0 padding (See RFC 7830)
+	EDNS0EDE          = 0xf     // EDNS0 extended DNS errors (See RFC 8914)
 	EDNS0LOCALSTART   = 0xFDE9  // Beginning of range reserved for local/experimental use (See RFC 6891)
 	EDNS0LOCALSTART   = 0xFDE9  // Beginning of range reserved for local/experimental use (See RFC 6891)
 	EDNS0LOCALEND     = 0xFFFE  // End of range reserved for local/experimental use (See RFC 6891)
 	EDNS0LOCALEND     = 0xFFFE  // End of range reserved for local/experimental use (See RFC 6891)
 	_DO               = 1 << 15 // DNSSEC OK
 	_DO               = 1 << 15 // DNSSEC OK
 )
 )
 
 
+// makeDataOpt is used to unpack the EDNS0 option(s) from a message.
+func makeDataOpt(code uint16) EDNS0 {
+	// All the EDNS0.* constants above need to be in this switch.
+	switch code {
+	case EDNS0LLQ:
+		return new(EDNS0_LLQ)
+	case EDNS0UL:
+		return new(EDNS0_UL)
+	case EDNS0NSID:
+		return new(EDNS0_NSID)
+	case EDNS0DAU:
+		return new(EDNS0_DAU)
+	case EDNS0DHU:
+		return new(EDNS0_DHU)
+	case EDNS0N3U:
+		return new(EDNS0_N3U)
+	case EDNS0SUBNET:
+		return new(EDNS0_SUBNET)
+	case EDNS0EXPIRE:
+		return new(EDNS0_EXPIRE)
+	case EDNS0COOKIE:
+		return new(EDNS0_COOKIE)
+	case EDNS0TCPKEEPALIVE:
+		return new(EDNS0_TCP_KEEPALIVE)
+	case EDNS0PADDING:
+		return new(EDNS0_PADDING)
+	case EDNS0EDE:
+		return new(EDNS0_EDE)
+	default:
+		e := new(EDNS0_LOCAL)
+		e.Code = code
+		return e
+	}
+}
+
 // OPT is the EDNS0 RR appended to messages to convey extra (meta) information.
 // OPT is the EDNS0 RR appended to messages to convey extra (meta) information.
 // See RFC 6891.
 // See RFC 6891.
 type OPT struct {
 type OPT struct {
@@ -73,6 +109,8 @@ func (rr *OPT) String() string {
 			s += "\n; LOCAL OPT: " + o.String()
 			s += "\n; LOCAL OPT: " + o.String()
 		case *EDNS0_PADDING:
 		case *EDNS0_PADDING:
 			s += "\n; PADDING: " + o.String()
 			s += "\n; PADDING: " + o.String()
+		case *EDNS0_EDE:
+			s += "\n; EDE: " + o.String()
 		}
 		}
 	}
 	}
 	return s
 	return s
@@ -88,11 +126,11 @@ func (rr *OPT) len(off int, compression map[string]struct{}) int {
 	return l
 	return l
 }
 }
 
 
-func (rr *OPT) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on OPT")
+func (*OPT) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "OPT records do not have a presentation format"}
 }
 }
 
 
-func (r1 *OPT) isDuplicate(r2 RR) bool { return false }
+func (rr *OPT) isDuplicate(r2 RR) bool { return false }
 
 
 // return the old value -> delete SetVersion?
 // return the old value -> delete SetVersion?
 
 
@@ -148,6 +186,16 @@ func (rr *OPT) SetDo(do ...bool) {
 	}
 	}
 }
 }
 
 
+// Z returns the Z part of the OPT RR as a uint16 with only the 15 least significant bits used.
+func (rr *OPT) Z() uint16 {
+	return uint16(rr.Hdr.Ttl & 0x7FFF)
+}
+
+// SetZ sets the Z part of the OPT RR, note only the 15 least significant bits of z are used.
+func (rr *OPT) SetZ(z uint16) {
+	rr.Hdr.Ttl = rr.Hdr.Ttl&^0x7FFF | uint32(z&0x7FFF)
+}
+
 // EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to it.
 // EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to it.
 type EDNS0 interface {
 type EDNS0 interface {
 	// Option returns the option code for the option.
 	// Option returns the option code for the option.
@@ -452,7 +500,7 @@ func (e *EDNS0_LLQ) copy() EDNS0 {
 	return &EDNS0_LLQ{e.Code, e.Version, e.Opcode, e.Error, e.Id, e.LeaseLife}
 	return &EDNS0_LLQ{e.Code, e.Version, e.Opcode, e.Error, e.Id, e.LeaseLife}
 }
 }
 
 
-// EDNS0_DUA implements the EDNS0 "DNSSEC Algorithm Understood" option. See RFC 6975.
+// EDNS0_DAU implements the EDNS0 "DNSSEC Algorithm Understood" option. See RFC 6975.
 type EDNS0_DAU struct {
 type EDNS0_DAU struct {
 	Code    uint16 // Always EDNS0DAU
 	Code    uint16 // Always EDNS0DAU
 	AlgCode []uint8
 	AlgCode []uint8
@@ -525,7 +573,7 @@ func (e *EDNS0_N3U) String() string {
 }
 }
 func (e *EDNS0_N3U) copy() EDNS0 { return &EDNS0_N3U{e.Code, e.AlgCode} }
 func (e *EDNS0_N3U) copy() EDNS0 { return &EDNS0_N3U{e.Code, e.AlgCode} }
 
 
-// EDNS0_EXPIRE implementes the EDNS0 option as described in RFC 7314.
+// EDNS0_EXPIRE implements the EDNS0 option as described in RFC 7314.
 type EDNS0_EXPIRE struct {
 type EDNS0_EXPIRE struct {
 	Code   uint16 // Always EDNS0EXPIRE
 	Code   uint16 // Always EDNS0EXPIRE
 	Expire uint32
 	Expire uint32
@@ -673,3 +721,101 @@ func (e *EDNS0_PADDING) copy() EDNS0 {
 	copy(b, e.Padding)
 	copy(b, e.Padding)
 	return &EDNS0_PADDING{b}
 	return &EDNS0_PADDING{b}
 }
 }
+
+// Extended DNS Error Codes (RFC 8914).
+const (
+	ExtendedErrorCodeOther uint16 = iota
+	ExtendedErrorCodeUnsupportedDNSKEYAlgorithm
+	ExtendedErrorCodeUnsupportedDSDigestType
+	ExtendedErrorCodeStaleAnswer
+	ExtendedErrorCodeForgedAnswer
+	ExtendedErrorCodeDNSSECIndeterminate
+	ExtendedErrorCodeDNSBogus
+	ExtendedErrorCodeSignatureExpired
+	ExtendedErrorCodeSignatureNotYetValid
+	ExtendedErrorCodeDNSKEYMissing
+	ExtendedErrorCodeRRSIGsMissing
+	ExtendedErrorCodeNoZoneKeyBitSet
+	ExtendedErrorCodeNSECMissing
+	ExtendedErrorCodeCachedError
+	ExtendedErrorCodeNotReady
+	ExtendedErrorCodeBlocked
+	ExtendedErrorCodeCensored
+	ExtendedErrorCodeFiltered
+	ExtendedErrorCodeProhibited
+	ExtendedErrorCodeStaleNXDOMAINAnswer
+	ExtendedErrorCodeNotAuthoritative
+	ExtendedErrorCodeNotSupported
+	ExtendedErrorCodeNoReachableAuthority
+	ExtendedErrorCodeNetworkError
+	ExtendedErrorCodeInvalidData
+)
+
+// ExtendedErrorCodeToString maps extended error info codes to a human readable
+// description.
+var ExtendedErrorCodeToString = map[uint16]string{
+	ExtendedErrorCodeOther:                      "Other",
+	ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: "Unsupported DNSKEY Algorithm",
+	ExtendedErrorCodeUnsupportedDSDigestType:    "Unsupported DS Digest Type",
+	ExtendedErrorCodeStaleAnswer:                "Stale Answer",
+	ExtendedErrorCodeForgedAnswer:               "Forged Answer",
+	ExtendedErrorCodeDNSSECIndeterminate:        "DNSSEC Indeterminate",
+	ExtendedErrorCodeDNSBogus:                   "DNSSEC Bogus",
+	ExtendedErrorCodeSignatureExpired:           "Signature Expired",
+	ExtendedErrorCodeSignatureNotYetValid:       "Signature Not Yet Valid",
+	ExtendedErrorCodeDNSKEYMissing:              "DNSKEY Missing",
+	ExtendedErrorCodeRRSIGsMissing:              "RRSIGs Missing",
+	ExtendedErrorCodeNoZoneKeyBitSet:            "No Zone Key Bit Set",
+	ExtendedErrorCodeNSECMissing:                "NSEC Missing",
+	ExtendedErrorCodeCachedError:                "Cached Error",
+	ExtendedErrorCodeNotReady:                   "Not Ready",
+	ExtendedErrorCodeBlocked:                    "Blocked",
+	ExtendedErrorCodeCensored:                   "Censored",
+	ExtendedErrorCodeFiltered:                   "Filtered",
+	ExtendedErrorCodeProhibited:                 "Prohibited",
+	ExtendedErrorCodeStaleNXDOMAINAnswer:        "Stale NXDOMAIN Answer",
+	ExtendedErrorCodeNotAuthoritative:           "Not Authoritative",
+	ExtendedErrorCodeNotSupported:               "Not Supported",
+	ExtendedErrorCodeNoReachableAuthority:       "No Reachable Authority",
+	ExtendedErrorCodeNetworkError:               "Network Error",
+	ExtendedErrorCodeInvalidData:                "Invalid Data",
+}
+
+// StringToExtendedErrorCode is a map from human readable descriptions to
+// extended error info codes.
+var StringToExtendedErrorCode = reverseInt16(ExtendedErrorCodeToString)
+
+// EDNS0_EDE option is used to return additional information about the cause of
+// DNS errors.
+type EDNS0_EDE struct {
+	InfoCode  uint16
+	ExtraText string
+}
+
+// Option implements the EDNS0 interface.
+func (e *EDNS0_EDE) Option() uint16 { return EDNS0EDE }
+func (e *EDNS0_EDE) copy() EDNS0    { return &EDNS0_EDE{e.InfoCode, e.ExtraText} }
+
+func (e *EDNS0_EDE) String() string {
+	info := strconv.FormatUint(uint64(e.InfoCode), 10)
+	if s, ok := ExtendedErrorCodeToString[e.InfoCode]; ok {
+		info += fmt.Sprintf(" (%s)", s)
+	}
+	return fmt.Sprintf("%s: (%s)", info, e.ExtraText)
+}
+
+func (e *EDNS0_EDE) pack() ([]byte, error) {
+	b := make([]byte, 2+len(e.ExtraText))
+	binary.BigEndian.PutUint16(b[0:], e.InfoCode)
+	copy(b[2:], []byte(e.ExtraText))
+	return b, nil
+}
+
+func (e *EDNS0_EDE) unpack(b []byte) error {
+	if len(b) < 2 {
+		return ErrBuf
+	}
+	e.InfoCode = binary.BigEndian.Uint16(b[0:])
+	e.ExtraText = string(b[2:])
+	return nil
+}

+ 12 - 12
vendor/github.com/miekg/dns/generate.go

@@ -20,13 +20,13 @@ import (
 // of $ after that are interpreted.
 // of $ after that are interpreted.
 func (zp *ZoneParser) generate(l lex) (RR, bool) {
 func (zp *ZoneParser) generate(l lex) (RR, bool) {
 	token := l.token
 	token := l.token
-	step := 1
+	step := int64(1)
 	if i := strings.IndexByte(token, '/'); i >= 0 {
 	if i := strings.IndexByte(token, '/'); i >= 0 {
 		if i+1 == len(token) {
 		if i+1 == len(token) {
 			return zp.setParseError("bad step in $GENERATE range", l)
 			return zp.setParseError("bad step in $GENERATE range", l)
 		}
 		}
 
 
-		s, err := strconv.Atoi(token[i+1:])
+		s, err := strconv.ParseInt(token[i+1:], 10, 64)
 		if err != nil || s <= 0 {
 		if err != nil || s <= 0 {
 			return zp.setParseError("bad step in $GENERATE range", l)
 			return zp.setParseError("bad step in $GENERATE range", l)
 		}
 		}
@@ -40,12 +40,12 @@ func (zp *ZoneParser) generate(l lex) (RR, bool) {
 		return zp.setParseError("bad start-stop in $GENERATE range", l)
 		return zp.setParseError("bad start-stop in $GENERATE range", l)
 	}
 	}
 
 
-	start, err := strconv.Atoi(sx[0])
+	start, err := strconv.ParseInt(sx[0], 10, 64)
 	if err != nil {
 	if err != nil {
 		return zp.setParseError("bad start in $GENERATE range", l)
 		return zp.setParseError("bad start in $GENERATE range", l)
 	}
 	}
 
 
-	end, err := strconv.Atoi(sx[1])
+	end, err := strconv.ParseInt(sx[1], 10, 64)
 	if err != nil {
 	if err != nil {
 		return zp.setParseError("bad stop in $GENERATE range", l)
 		return zp.setParseError("bad stop in $GENERATE range", l)
 	}
 	}
@@ -94,10 +94,10 @@ type generateReader struct {
 	s  string
 	s  string
 	si int
 	si int
 
 
-	cur   int
-	start int
-	end   int
-	step  int
+	cur   int64
+	start int64
+	end   int64
+	step  int64
 
 
 	mod bytes.Buffer
 	mod bytes.Buffer
 
 
@@ -173,7 +173,7 @@ func (r *generateReader) ReadByte() (byte, error) {
 			return '$', nil
 			return '$', nil
 		}
 		}
 
 
-		var offset int
+		var offset int64
 
 
 		// Search for { and }
 		// Search for { and }
 		if r.s[si+1] == '{' {
 		if r.s[si+1] == '{' {
@@ -208,7 +208,7 @@ func (r *generateReader) ReadByte() (byte, error) {
 }
 }
 
 
 // Convert a $GENERATE modifier 0,0,d to something Printf can deal with.
 // Convert a $GENERATE modifier 0,0,d to something Printf can deal with.
-func modToPrintf(s string) (string, int, string) {
+func modToPrintf(s string) (string, int64, string) {
 	// Modifier is { offset [ ,width [ ,base ] ] } - provide default
 	// Modifier is { offset [ ,width [ ,base ] ] } - provide default
 	// values for optional width and type, if necessary.
 	// values for optional width and type, if necessary.
 	var offStr, widthStr, base string
 	var offStr, widthStr, base string
@@ -229,12 +229,12 @@ func modToPrintf(s string) (string, int, string) {
 		return "", 0, "bad base in $GENERATE"
 		return "", 0, "bad base in $GENERATE"
 	}
 	}
 
 
-	offset, err := strconv.Atoi(offStr)
+	offset, err := strconv.ParseInt(offStr, 10, 64)
 	if err != nil {
 	if err != nil {
 		return "", 0, "bad offset in $GENERATE"
 		return "", 0, "bad offset in $GENERATE"
 	}
 	}
 
 
-	width, err := strconv.Atoi(widthStr)
+	width, err := strconv.ParseInt(widthStr, 10, 64)
 	if err != nil || width < 0 || width > 255 {
 	if err != nil || width < 0 || width > 255 {
 		return "", 0, "bad width in $GENERATE"
 		return "", 0, "bad width in $GENERATE"
 	}
 	}

+ 4 - 6
vendor/github.com/miekg/dns/go.mod

@@ -1,11 +1,9 @@
 module github.com/miekg/dns
 module github.com/miekg/dns
 
 
-go 1.12
+go 1.14
 
 
 require (
 require (
-	golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
-	golang.org/x/net v0.0.0-20190923162816-aa69164e4478
-	golang.org/x/sync v0.0.0-20190423024810-112230192c58
-	golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe
-	golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect
+	golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
+	golang.org/x/sys v0.0.0-20210303074136-134d130e1a04
 )
 )

+ 9 - 38
vendor/github.com/miekg/dns/go.sum

@@ -1,39 +1,10 @@
-golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 h1:Vk3wNqEZwyGyei9yq5ekj7frek2u7HUfffJ1/opblzc=
-golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
-golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392 h1:ACG4HJsFiNMf47Y4PeRoebLNy/2lXT9EtprMuTFWt1M=
-golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/net v0.0.0-20180926154720-4dfa2610cdf3 h1:dgd4x4kJt7G4k4m93AYLzM8Ni6h2qLTfh9n9vXJT3/0=
-golang.org/x/net v0.0.0-20180926154720-4dfa2610cdf3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM=
-golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g=
-golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 h1:O33LKL7WyJgjN9CvxfTIomjIClbd/Kq86/iipowHQU0=
-golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190904154756-749cb33beabd h1:DBH9mDw0zluJT/R+nGuV3jWFWLFaHyYZWD4tOT+cjn0=
-golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M=
-golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
+golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
-golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

+ 1 - 1
vendor/github.com/miekg/dns/labels.go

@@ -10,7 +10,7 @@ package dns
 // escaped dots (\.) for instance.
 // escaped dots (\.) for instance.
 // s must be a syntactically valid domain name, see IsDomainName.
 // s must be a syntactically valid domain name, see IsDomainName.
 func SplitDomainName(s string) (labels []string) {
 func SplitDomainName(s string) (labels []string) {
-	if len(s) == 0 {
+	if s == "" {
 		return nil
 		return nil
 	}
 	}
 	fqdnEnd := 0 // offset of the final '.' or the length of the name
 	fqdnEnd := 0 // offset of the final '.' or the length of the name

+ 0 - 0
vendor/github.com/miekg/dns/listen_go_not111.go → vendor/github.com/miekg/dns/listen_no_reuseport.go


+ 0 - 0
vendor/github.com/miekg/dns/listen_go111.go → vendor/github.com/miekg/dns/listen_reuseport.go


+ 15 - 13
vendor/github.com/miekg/dns/msg.go

@@ -398,17 +398,12 @@ Loop:
 				return "", lenmsg, ErrLongDomain
 				return "", lenmsg, ErrLongDomain
 			}
 			}
 			for _, b := range msg[off : off+c] {
 			for _, b := range msg[off : off+c] {
-				switch b {
-				case '.', '(', ')', ';', ' ', '@':
-					fallthrough
-				case '"', '\\':
+				if isDomainNameLabelSpecial(b) {
 					s = append(s, '\\', b)
 					s = append(s, '\\', b)
-				default:
-					if b < ' ' || b > '~' { // unprintable, use \DDD
-						s = append(s, escapeByte(b)...)
-					} else {
-						s = append(s, b)
-					}
+				} else if b < ' ' || b > '~' {
+					s = append(s, escapeByte(b)...)
+				} else {
+					s = append(s, b)
 				}
 				}
 			}
 			}
 			s = append(s, '.')
 			s = append(s, '.')
@@ -629,11 +624,18 @@ func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err
 		rr = &RFC3597{Hdr: h}
 		rr = &RFC3597{Hdr: h}
 	}
 	}
 
 
-	if noRdata(h) {
-		return rr, off, nil
+	if off < 0 || off > len(msg) {
+		return &h, off, &Error{err: "bad off"}
 	}
 	}
 
 
 	end := off + int(h.Rdlength)
 	end := off + int(h.Rdlength)
+	if end < off || end > len(msg) {
+		return &h, end, &Error{err: "bad rdlength"}
+	}
+
+	if noRdata(h) {
+		return rr, off, nil
+	}
 
 
 	off, err = rr.unpack(msg, off)
 	off, err = rr.unpack(msg, off)
 	if err != nil {
 	if err != nil {
@@ -740,7 +742,7 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compression
 	}
 	}
 
 
 	// Set extended rcode unconditionally if we have an opt, this will allow
 	// Set extended rcode unconditionally if we have an opt, this will allow
-	// reseting the extended rcode bits if they need to.
+	// resetting the extended rcode bits if they need to.
 	if opt := dns.IsEdns0(); opt != nil {
 	if opt := dns.IsEdns0(); opt != nil {
 		opt.SetExtendedRcode(uint16(dns.Rcode))
 		opt.SetExtendedRcode(uint16(dns.Rcode))
 	} else if dns.Rcode > 0xF {
 	} else if dns.Rcode > 0xF {

+ 7 - 0
vendor/github.com/miekg/dns/msg_generate.go

@@ -35,6 +35,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 	if !ok {
 		return nil, false
 		return nil, false
 	}
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 		return st, false
 	}
 	}
@@ -110,6 +113,8 @@ return off, err
 					o("off, err = packDataOpt(rr.%s, msg, off)\n")
 					o("off, err = packDataOpt(rr.%s, msg, off)\n")
 				case `dns:"nsec"`:
 				case `dns:"nsec"`:
 					o("off, err = packDataNsec(rr.%s, msg, off)\n")
 					o("off, err = packDataNsec(rr.%s, msg, off)\n")
+				case `dns:"pairs"`:
+					o("off, err = packDataSVCB(rr.%s, msg, off)\n")
 				case `dns:"domain-name"`:
 				case `dns:"domain-name"`:
 					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
 					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
 				case `dns:"apl"`:
 				case `dns:"apl"`:
@@ -236,6 +241,8 @@ return off, err
 					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
 					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
 				case `dns:"nsec"`:
 				case `dns:"nsec"`:
 					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
 					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
+				case `dns:"pairs"`:
+					o("rr.%s, off, err = unpackDataSVCB(msg, off)\n")
 				case `dns:"domain-name"`:
 				case `dns:"domain-name"`:
 					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
 					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
 				case `dns:"apl"`:
 				case `dns:"apl"`:

+ 73 - 82
vendor/github.com/miekg/dns/msg_helpers.go

@@ -6,6 +6,7 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
 	"net"
 	"net"
+	"sort"
 	"strings"
 	"strings"
 )
 )
 
 
@@ -423,86 +424,12 @@ Option:
 	if off+int(optlen) > len(msg) {
 	if off+int(optlen) > len(msg) {
 		return nil, len(msg), &Error{err: "overflow unpacking opt"}
 		return nil, len(msg), &Error{err: "overflow unpacking opt"}
 	}
 	}
-	switch code {
-	case EDNS0NSID:
-		e := new(EDNS0_NSID)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0SUBNET:
-		e := new(EDNS0_SUBNET)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0COOKIE:
-		e := new(EDNS0_COOKIE)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0EXPIRE:
-		e := new(EDNS0_EXPIRE)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0UL:
-		e := new(EDNS0_UL)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0LLQ:
-		e := new(EDNS0_LLQ)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0DAU:
-		e := new(EDNS0_DAU)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0DHU:
-		e := new(EDNS0_DHU)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0N3U:
-		e := new(EDNS0_N3U)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0PADDING:
-		e := new(EDNS0_PADDING)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	default:
-		e := new(EDNS0_LOCAL)
-		e.Code = code
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
+	e := makeDataOpt(code)
+	if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
+		return nil, len(msg), err
 	}
 	}
+	edns = append(edns, e)
+	off += int(optlen)
 
 
 	if off < len(msg) {
 	if off < len(msg) {
 		goto Option
 		goto Option
@@ -521,9 +448,7 @@ func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
 		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
 		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
 		off += 4
 		off += 4
 		if off+len(b) > len(msg) {
 		if off+len(b) > len(msg) {
-			copy(msg[off:], b)
-			off = len(msg)
-			continue
+			return len(msg), &Error{err: "overflow packing opt"}
 		}
 		}
 		// Actual data
 		// Actual data
 		copy(msg[off:off+len(b)], b)
 		copy(msg[off:off+len(b)], b)
@@ -659,6 +584,65 @@ func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
 	return off, nil
 	return off, nil
 }
 }
 
 
+func unpackDataSVCB(msg []byte, off int) ([]SVCBKeyValue, int, error) {
+	var xs []SVCBKeyValue
+	var code uint16
+	var length uint16
+	var err error
+	for off < len(msg) {
+		code, off, err = unpackUint16(msg, off)
+		if err != nil {
+			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
+		}
+		length, off, err = unpackUint16(msg, off)
+		if err != nil || off+int(length) > len(msg) {
+			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
+		}
+		e := makeSVCBKeyValue(SVCBKey(code))
+		if e == nil {
+			return nil, len(msg), &Error{err: "bad SVCB key"}
+		}
+		if err := e.unpack(msg[off : off+int(length)]); err != nil {
+			return nil, len(msg), err
+		}
+		if len(xs) > 0 && e.Key() <= xs[len(xs)-1].Key() {
+			return nil, len(msg), &Error{err: "SVCB keys not in strictly increasing order"}
+		}
+		xs = append(xs, e)
+		off += int(length)
+	}
+	return xs, off, nil
+}
+
+func packDataSVCB(pairs []SVCBKeyValue, msg []byte, off int) (int, error) {
+	pairs = append([]SVCBKeyValue(nil), pairs...)
+	sort.Slice(pairs, func(i, j int) bool {
+		return pairs[i].Key() < pairs[j].Key()
+	})
+	prev := svcb_RESERVED
+	for _, el := range pairs {
+		if el.Key() == prev {
+			return len(msg), &Error{err: "repeated SVCB keys are not allowed"}
+		}
+		prev = el.Key()
+		packed, err := el.pack()
+		if err != nil {
+			return len(msg), err
+		}
+		off, err = packUint16(uint16(el.Key()), msg, off)
+		if err != nil {
+			return len(msg), &Error{err: "overflow packing SVCB"}
+		}
+		off, err = packUint16(uint16(len(packed)), msg, off)
+		if err != nil || off+len(packed) > len(msg) {
+			return len(msg), &Error{err: "overflow packing SVCB"}
+		}
+		copy(msg[off:off+len(packed)], packed)
+		off += len(packed)
+	}
+	return off, nil
+}
+
 func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
 func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
 	var (
 	var (
 		servers []string
 		servers []string
@@ -730,6 +714,13 @@ func packDataAplPrefix(p *APLPrefix, msg []byte, off int) (int, error) {
 	if p.Negation {
 	if p.Negation {
 		n = 0x80
 		n = 0x80
 	}
 	}
+
+	// trim trailing zero bytes as specified in RFC3123 Sections 4.1 and 4.2.
+	i := len(addr) - 1
+	for ; i >= 0 && addr[i] == 0; i-- {
+	}
+	addr = addr[:i+1]
+
 	adflen := uint8(len(addr)) & 0x7f
 	adflen := uint8(len(addr)) & 0x7f
 	off, err = packUint8(n|adflen, msg, off)
 	off, err = packUint8(n|adflen, msg, off)
 	if err != nil {
 	if err != nil {

+ 11 - 5
vendor/github.com/miekg/dns/msg_truncate.go

@@ -8,8 +8,14 @@ package dns
 // record adding as many records as possible without exceeding the
 // record adding as many records as possible without exceeding the
 // requested buffer size.
 // requested buffer size.
 //
 //
+// If the message fits within the requested size without compression,
+// Truncate will set the message's Compress attribute to false. It is
+// the caller's responsibility to set it back to true if they wish to
+// compress the payload regardless of size.
+//
 // The TC bit will be set if any records were excluded from the message.
 // The TC bit will be set if any records were excluded from the message.
-// This indicates to that the client should retry over TCP.
+// If the TC bit is already set on the message it will be retained.
+// TC indicates that the client should retry over TCP.
 //
 //
 // According to RFC 2181, the TC bit should only be set if not all of the
 // According to RFC 2181, the TC bit should only be set if not all of the
 // "required" RRs can be included in the response. Unfortunately, we have
 // "required" RRs can be included in the response. Unfortunately, we have
@@ -28,11 +34,11 @@ func (dns *Msg) Truncate(size int) {
 	}
 	}
 
 
 	// RFC 6891 mandates that the payload size in an OPT record
 	// RFC 6891 mandates that the payload size in an OPT record
-	// less than 512 bytes must be treated as equal to 512 bytes.
+	// less than 512 (MinMsgSize) bytes must be treated as equal to 512 bytes.
 	//
 	//
 	// For ease of use, we impose that restriction here.
 	// For ease of use, we impose that restriction here.
-	if size < 512 {
-		size = 512
+	if size < MinMsgSize {
+		size = MinMsgSize
 	}
 	}
 
 
 	l := msgLenWithCompressionMap(dns, nil) // uncompressed length
 	l := msgLenWithCompressionMap(dns, nil) // uncompressed length
@@ -77,7 +83,7 @@ func (dns *Msg) Truncate(size int) {
 	}
 	}
 
 
 	// See the function documentation for when we set this.
 	// See the function documentation for when we set this.
-	dns.Truncated = len(dns.Answer) > numAnswer ||
+	dns.Truncated = dns.Truncated || len(dns.Answer) > numAnswer ||
 		len(dns.Ns) > numNS || len(dns.Extra) > numExtra
 		len(dns.Ns) > numNS || len(dns.Extra) > numExtra
 
 
 	dns.Answer = dns.Answer[:numAnswer]
 	dns.Answer = dns.Answer[:numAnswer]

+ 2 - 2
vendor/github.com/miekg/dns/privaterr.go

@@ -6,7 +6,7 @@ import "strings"
 // RFC 6895. This allows one to experiment with new RR types, without requesting an
 // RFC 6895. This allows one to experiment with new RR types, without requesting an
 // official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
 // official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
 type PrivateRdata interface {
 type PrivateRdata interface {
-	// String returns the text presentaton of the Rdata of the Private RR.
+	// String returns the text presentation of the Rdata of the Private RR.
 	String() string
 	String() string
 	// Parse parses the Rdata of the private RR.
 	// Parse parses the Rdata of the private RR.
 	Parse([]string) error
 	Parse([]string) error
@@ -90,7 +90,7 @@ Fetch:
 	return nil
 	return nil
 }
 }
 
 
-func (r1 *PrivateRR) isDuplicate(r2 RR) bool { return false }
+func (r *PrivateRR) isDuplicate(r2 RR) bool { return false }
 
 
 // PrivateHandle registers a private resource record type. It requires
 // PrivateHandle registers a private resource record type. It requires
 // string and numeric representation of private RR type and generator function as argument.
 // string and numeric representation of private RR type and generator function as argument.

+ 59 - 22
vendor/github.com/miekg/dns/scan.go

@@ -150,6 +150,9 @@ func ReadRR(r io.Reader, file string) (RR, error) {
 // The text "; this is comment" is returned from Comment. Comments inside
 // The text "; this is comment" is returned from Comment. Comments inside
 // the RR are returned concatenated along with the RR. Comments on a line
 // the RR are returned concatenated along with the RR. Comments on a line
 // by themselves are discarded.
 // by themselves are discarded.
+//
+// Callers should not assume all returned data in an Resource Record is
+// syntactically correct, e.g. illegal base64 in RRSIGs will be returned as-is.
 type ZoneParser struct {
 type ZoneParser struct {
 	c *zlexer
 	c *zlexer
 
 
@@ -577,10 +580,23 @@ func (zp *ZoneParser) Next() (RR, bool) {
 
 
 			st = zExpectRdata
 			st = zExpectRdata
 		case zExpectRdata:
 		case zExpectRdata:
-			var rr RR
-			if newFn, ok := TypeToRR[h.Rrtype]; ok && canParseAsRR(h.Rrtype) {
+			var (
+				rr             RR
+				parseAsRFC3597 bool
+			)
+			if newFn, ok := TypeToRR[h.Rrtype]; ok {
 				rr = newFn()
 				rr = newFn()
 				*rr.Header() = *h
 				*rr.Header() = *h
+
+				// We may be parsing a known RR type using the RFC3597 format.
+				// If so, we handle that here in a generic way.
+				//
+				// This is also true for PrivateRR types which will have the
+				// RFC3597 parsing done for them and the Unpack method called
+				// to populate the RR instead of simply deferring to Parse.
+				if zp.c.Peek().token == "\\#" {
+					parseAsRFC3597 = true
+				}
 			} else {
 			} else {
 				rr = &RFC3597{Hdr: *h}
 				rr = &RFC3597{Hdr: *h}
 			}
 			}
@@ -600,13 +616,18 @@ func (zp *ZoneParser) Next() (RR, bool) {
 				return zp.setParseError("unexpected newline", l)
 				return zp.setParseError("unexpected newline", l)
 			}
 			}
 
 
-			if err := rr.parse(zp.c, zp.origin); err != nil {
+			parseAsRR := rr
+			if parseAsRFC3597 {
+				parseAsRR = &RFC3597{Hdr: *h}
+			}
+
+			if err := parseAsRR.parse(zp.c, zp.origin); err != nil {
 				// err is a concrete *ParseError without the file field set.
 				// err is a concrete *ParseError without the file field set.
 				// The setParseError call below will construct a new
 				// The setParseError call below will construct a new
 				// *ParseError with file set to zp.file.
 				// *ParseError with file set to zp.file.
 
 
-				// If err.lex is nil than we have encounter an unknown RR type
-				// in that case we substitute our current lex token.
+				// err.lex may be nil in which case we substitute our current
+				// lex token.
 				if err.lex == (lex{}) {
 				if err.lex == (lex{}) {
 					return zp.setParseError(err.err, l)
 					return zp.setParseError(err.err, l)
 				}
 				}
@@ -614,6 +635,13 @@ func (zp *ZoneParser) Next() (RR, bool) {
 				return zp.setParseError(err.err, err.lex)
 				return zp.setParseError(err.err, err.lex)
 			}
 			}
 
 
+			if parseAsRFC3597 {
+				err := parseAsRR.(*RFC3597).fromRFC3597(rr)
+				if err != nil {
+					return zp.setParseError(err.Error(), l)
+				}
+			}
+
 			return rr, true
 			return rr, true
 		}
 		}
 	}
 	}
@@ -623,18 +651,6 @@ func (zp *ZoneParser) Next() (RR, bool) {
 	return nil, false
 	return nil, false
 }
 }
 
 
-// canParseAsRR returns true if the record type can be parsed as a
-// concrete RR. It blacklists certain record types that must be parsed
-// according to RFC 3597 because they lack a presentation format.
-func canParseAsRR(rrtype uint16) bool {
-	switch rrtype {
-	case TypeANY, TypeNULL, TypeOPT, TypeTSIG:
-		return false
-	default:
-		return true
-	}
-}
-
 type zlexer struct {
 type zlexer struct {
 	br io.ByteReader
 	br io.ByteReader
 
 
@@ -1210,11 +1226,29 @@ func stringToCm(token string) (e, m uint8, ok bool) {
 		if cmeters, err = strconv.Atoi(s[1]); err != nil {
 		if cmeters, err = strconv.Atoi(s[1]); err != nil {
 			return
 			return
 		}
 		}
+		// There's no point in having more than 2 digits in this part, and would rather make the implementation complicated ('123' should be treated as '12').
+		// So we simply reject it.
+		// We also make sure the first character is a digit to reject '+-' signs.
+		if len(s[1]) > 2 || s[1][0] < '0' || s[1][0] > '9' {
+			return
+		}
+		if len(s[1]) == 1 {
+			// 'nn.1' must be treated as 'nn-meters and 10cm, not 1cm.
+			cmeters *= 10
+		}
+		if s[0] == "" {
+			// This will allow omitting the 'meter' part, like .01 (meaning 0.01m = 1cm).
+			break
+		}
 		fallthrough
 		fallthrough
 	case 1:
 	case 1:
 		if meters, err = strconv.Atoi(s[0]); err != nil {
 		if meters, err = strconv.Atoi(s[0]); err != nil {
 			return
 			return
 		}
 		}
+		// RFC1876 states the max value is 90000000.00.  The latter two conditions enforce it.
+		if s[0][0] < '0' || s[0][0] > '9' || meters > 90000000 || (meters == 90000000 && cmeters != 0) {
+			return
+		}
 	case 0:
 	case 0:
 		// huh?
 		// huh?
 		return 0, 0, false
 		return 0, 0, false
@@ -1227,13 +1261,10 @@ func stringToCm(token string) (e, m uint8, ok bool) {
 		e = 0
 		e = 0
 		val = cmeters
 		val = cmeters
 	}
 	}
-	for val > 10 {
+	for val >= 10 {
 		e++
 		e++
 		val /= 10
 		val /= 10
 	}
 	}
-	if e > 9 {
-		ok = false
-	}
 	m = uint8(val)
 	m = uint8(val)
 	return
 	return
 }
 }
@@ -1275,6 +1306,9 @@ func appendOrigin(name, origin string) string {
 
 
 // LOC record helper function
 // LOC record helper function
 func locCheckNorth(token string, latitude uint32) (uint32, bool) {
 func locCheckNorth(token string, latitude uint32) (uint32, bool) {
+	if latitude > 90*1000*60*60 {
+		return latitude, false
+	}
 	switch token {
 	switch token {
 	case "n", "N":
 	case "n", "N":
 		return LOC_EQUATOR + latitude, true
 		return LOC_EQUATOR + latitude, true
@@ -1286,6 +1320,9 @@ func locCheckNorth(token string, latitude uint32) (uint32, bool) {
 
 
 // LOC record helper function
 // LOC record helper function
 func locCheckEast(token string, longitude uint32) (uint32, bool) {
 func locCheckEast(token string, longitude uint32) (uint32, bool) {
+	if longitude > 180*1000*60*60 {
+		return longitude, false
+	}
 	switch token {
 	switch token {
 	case "e", "E":
 	case "e", "E":
 		return LOC_EQUATOR + longitude, true
 		return LOC_EQUATOR + longitude, true
@@ -1318,7 +1355,7 @@ func stringToNodeID(l lex) (uint64, *ParseError) {
 	if len(l.token) < 19 {
 	if len(l.token) < 19 {
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 	}
 	}
-	// There must be three colons at fixes postitions, if not its a parse error
+	// There must be three colons at fixes positions, if not its a parse error
 	if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
 	if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 	}
 	}

+ 53 - 18
vendor/github.com/miekg/dns/scan_rr.go

@@ -590,7 +590,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 	// North
 	// North
 	l, _ := c.Next()
 	l, _ := c.Next()
 	i, e := strconv.ParseUint(l.token, 10, 32)
 	i, e := strconv.ParseUint(l.token, 10, 32)
-	if e != nil || l.err {
+	if e != nil || l.err || i > 90 {
 		return &ParseError{"", "bad LOC Latitude", l}
 		return &ParseError{"", "bad LOC Latitude", l}
 	}
 	}
 	rr.Latitude = 1000 * 60 * 60 * uint32(i)
 	rr.Latitude = 1000 * 60 * 60 * uint32(i)
@@ -601,7 +601,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 	if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok {
 	if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok {
 		goto East
 		goto East
 	}
 	}
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
 		return &ParseError{"", "bad LOC Latitude minutes", l}
 		return &ParseError{"", "bad LOC Latitude minutes", l}
 	} else {
 	} else {
 		rr.Latitude += 1000 * 60 * uint32(i)
 		rr.Latitude += 1000 * 60 * uint32(i)
@@ -609,7 +609,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 
 
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil || l.err {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
 		return &ParseError{"", "bad LOC Latitude seconds", l}
 		return &ParseError{"", "bad LOC Latitude seconds", l}
 	} else {
 	} else {
 		rr.Latitude += uint32(1000 * i)
 		rr.Latitude += uint32(1000 * i)
@@ -627,7 +627,7 @@ East:
 	// East
 	// East
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 180 {
 		return &ParseError{"", "bad LOC Longitude", l}
 		return &ParseError{"", "bad LOC Longitude", l}
 	} else {
 	} else {
 		rr.Longitude = 1000 * 60 * 60 * uint32(i)
 		rr.Longitude = 1000 * 60 * 60 * uint32(i)
@@ -638,14 +638,14 @@ East:
 	if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok {
 	if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok {
 		goto Altitude
 		goto Altitude
 	}
 	}
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
 		return &ParseError{"", "bad LOC Longitude minutes", l}
 		return &ParseError{"", "bad LOC Longitude minutes", l}
 	} else {
 	} else {
 		rr.Longitude += 1000 * 60 * uint32(i)
 		rr.Longitude += 1000 * 60 * uint32(i)
 	}
 	}
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil || l.err {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
 		return &ParseError{"", "bad LOC Longitude seconds", l}
 		return &ParseError{"", "bad LOC Longitude seconds", l}
 	} else {
 	} else {
 		rr.Longitude += uint32(1000 * i)
 		rr.Longitude += uint32(1000 * i)
@@ -662,13 +662,13 @@ East:
 Altitude:
 Altitude:
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad LOC Altitude", l}
 		return &ParseError{"", "bad LOC Altitude", l}
 	}
 	}
 	if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' {
 	if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' {
 		l.token = l.token[0 : len(l.token)-1]
 		l.token = l.token[0 : len(l.token)-1]
 	}
 	}
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil {
 		return &ParseError{"", "bad LOC Altitude", l}
 		return &ParseError{"", "bad LOC Altitude", l}
 	} else {
 	} else {
 		rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5)
 		rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5)
@@ -722,7 +722,7 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError {
 
 
 	c.Next()        // zBlank
 	c.Next()        // zBlank
 	l, _ = c.Next() // zString
 	l, _ = c.Next() // zString
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad HIP Hit", l}
 		return &ParseError{"", "bad HIP Hit", l}
 	}
 	}
 	rr.Hit = l.token // This can not contain spaces, see RFC 5205 Section 6.
 	rr.Hit = l.token // This can not contain spaces, see RFC 5205 Section 6.
@@ -730,11 +730,15 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError {
 
 
 	c.Next()        // zBlank
 	c.Next()        // zBlank
 	l, _ = c.Next() // zString
 	l, _ = c.Next() // zString
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad HIP PublicKey", l}
 		return &ParseError{"", "bad HIP PublicKey", l}
 	}
 	}
 	rr.PublicKey = l.token // This cannot contain spaces
 	rr.PublicKey = l.token // This cannot contain spaces
-	rr.PublicKeyLength = uint16(base64.StdEncoding.DecodedLen(len(rr.PublicKey)))
+	decodedPK, decodedPKerr := base64.StdEncoding.DecodeString(rr.PublicKey)
+	if decodedPKerr != nil {
+		return &ParseError{"", "bad HIP PublicKey", l}
+	}
+	rr.PublicKeyLength = uint16(len(decodedPK))
 
 
 	// RendezvousServers (if any)
 	// RendezvousServers (if any)
 	l, _ = c.Next()
 	l, _ = c.Next()
@@ -846,6 +850,38 @@ func (rr *CSYNC) parse(c *zlexer, o string) *ParseError {
 	return nil
 	return nil
 }
 }
 
 
+func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError {
+	l, _ := c.Next()
+	i, e := strconv.ParseUint(l.token, 10, 32)
+	if e != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Serial", l}
+	}
+	rr.Serial = uint32(i)
+
+	c.Next() // zBlank
+	l, _ = c.Next()
+	i, e1 := strconv.ParseUint(l.token, 10, 8)
+	if e1 != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Scheme", l}
+	}
+	rr.Scheme = uint8(i)
+
+	c.Next() // zBlank
+	l, _ = c.Next()
+	i, err := strconv.ParseUint(l.token, 10, 8)
+	if err != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Hash Algorithm", l}
+	}
+	rr.Hash = uint8(i)
+
+	s, e2 := endingToString(c, "bad ZONEMD Digest")
+	if e2 != nil {
+		return e2
+	}
+	rr.Digest = s
+	return nil
+}
+
 func (rr *SIG) parse(c *zlexer, o string) *ParseError { return rr.RRSIG.parse(c, o) }
 func (rr *SIG) parse(c *zlexer, o string) *ParseError { return rr.RRSIG.parse(c, o) }
 
 
 func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
 func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
@@ -893,8 +929,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
 	l, _ = c.Next()
 	l, _ = c.Next()
 	if i, err := StringToTime(l.token); err != nil {
 	if i, err := StringToTime(l.token); err != nil {
 		// Try to see if all numeric and use it as epoch
 		// Try to see if all numeric and use it as epoch
-		if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
-			// TODO(miek): error out on > MAX_UINT32, same below
+		if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
 			rr.Expiration = uint32(i)
 			rr.Expiration = uint32(i)
 		} else {
 		} else {
 			return &ParseError{"", "bad RRSIG Expiration", l}
 			return &ParseError{"", "bad RRSIG Expiration", l}
@@ -906,7 +941,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
 	if i, err := StringToTime(l.token); err != nil {
 	if i, err := StringToTime(l.token); err != nil {
-		if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
+		if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
 			rr.Inception = uint32(i)
 			rr.Inception = uint32(i)
 		} else {
 		} else {
 			return &ParseError{"", "bad RRSIG Inception", l}
 			return &ParseError{"", "bad RRSIG Inception", l}
@@ -998,7 +1033,7 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError {
 	rr.Iterations = uint16(i)
 	rr.Iterations = uint16(i)
 	c.Next()
 	c.Next()
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad NSEC3 Salt", l}
 		return &ParseError{"", "bad NSEC3 Salt", l}
 	}
 	}
 	if l.token != "-" {
 	if l.token != "-" {
@@ -1008,7 +1043,7 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError {
 
 
 	c.Next()
 	c.Next()
 	l, _ = c.Next()
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad NSEC3 NextDomain", l}
 		return &ParseError{"", "bad NSEC3 NextDomain", l}
 	}
 	}
 	rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits)
 	rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits)
@@ -1388,7 +1423,7 @@ func (rr *RFC3597) parse(c *zlexer, o string) *ParseError {
 
 
 	c.Next() // zBlank
 	c.Next() // zBlank
 	l, _ = c.Next()
 	l, _ = c.Next()
-	rdlength, e := strconv.Atoi(l.token)
+	rdlength, e := strconv.ParseUint(l.token, 10, 16)
 	if e != nil || l.err {
 	if e != nil || l.err {
 		return &ParseError{"", "bad RFC3597 Rdata ", l}
 		return &ParseError{"", "bad RFC3597 Rdata ", l}
 	}
 	}
@@ -1397,7 +1432,7 @@ func (rr *RFC3597) parse(c *zlexer, o string) *ParseError {
 	if e1 != nil {
 	if e1 != nil {
 		return e1
 		return e1
 	}
 	}
-	if rdlength*2 != len(s) {
+	if int(rdlength)*2 != len(s) {
 		return &ParseError{"", "bad RFC3597 Rdata", l}
 		return &ParseError{"", "bad RFC3597 Rdata", l}
 	}
 	}
 	rr.Rdata = s
 	rr.Rdata = s

+ 2 - 2
vendor/github.com/miekg/dns/serve_mux.go

@@ -91,7 +91,7 @@ func (mux *ServeMux) HandleRemove(pattern string) {
 // are redirected to the parent zone (if that is also registered),
 // are redirected to the parent zone (if that is also registered),
 // otherwise the child gets the query.
 // otherwise the child gets the query.
 //
 //
-// If no handler is found, or there is no question, a standard SERVFAIL
+// If no handler is found, or there is no question, a standard REFUSED
 // message is returned
 // message is returned
 func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
 func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
 	var h Handler
 	var h Handler
@@ -102,7 +102,7 @@ func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
 	if h != nil {
 	if h != nil {
 		h.ServeDNS(w, req)
 		h.ServeDNS(w, req)
 	} else {
 	} else {
-		HandleFailed(w, req)
+		handleRefused(w, req)
 	}
 	}
 }
 }
 
 

+ 92 - 28
vendor/github.com/miekg/dns/server.go

@@ -72,13 +72,22 @@ type response struct {
 	tsigStatus     error
 	tsigStatus     error
 	tsigRequestMAC string
 	tsigRequestMAC string
 	tsigSecret     map[string]string // the tsig secrets
 	tsigSecret     map[string]string // the tsig secrets
-	udp            *net.UDPConn      // i/o connection if UDP was used
+	udp            net.PacketConn    // i/o connection if UDP was used
 	tcp            net.Conn          // i/o connection if TCP was used
 	tcp            net.Conn          // i/o connection if TCP was used
 	udpSession     *SessionUDP       // oob data to get egress interface right
 	udpSession     *SessionUDP       // oob data to get egress interface right
+	pcSession      net.Addr          // address to use when writing to a generic net.PacketConn
 	writer         Writer            // writer to output the raw DNS bits
 	writer         Writer            // writer to output the raw DNS bits
 }
 }
 
 
+// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
+func handleRefused(w ResponseWriter, r *Msg) {
+	m := new(Msg)
+	m.SetRcode(r, RcodeRefused)
+	w.WriteMsg(m)
+}
+
 // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
 // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
+// Deprecated: This function is going away.
 func HandleFailed(w ResponseWriter, r *Msg) {
 func HandleFailed(w ResponseWriter, r *Msg) {
 	m := new(Msg)
 	m := new(Msg)
 	m.SetRcode(r, RcodeServerFailure)
 	m.SetRcode(r, RcodeServerFailure)
@@ -139,12 +148,24 @@ type Reader interface {
 	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
 	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
 }
 }
 
 
-// defaultReader is an adapter for the Server struct that implements the Reader interface
-// using the readTCP and readUDP func of the embedded Server.
+// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
+type PacketConnReader interface {
+	Reader
+
+	// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
+	// alter connection properties, for example the read-deadline.
+	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
+}
+
+// defaultReader is an adapter for the Server struct that implements the Reader and
+// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
+// of the embedded Server.
 type defaultReader struct {
 type defaultReader struct {
 	*Server
 	*Server
 }
 }
 
 
+var _ PacketConnReader = defaultReader{}
+
 func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
 func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
 	return dr.readTCP(conn, timeout)
 	return dr.readTCP(conn, timeout)
 }
 }
@@ -153,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
 	return dr.readUDP(conn, timeout)
 	return dr.readUDP(conn, timeout)
 }
 }
 
 
+func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+	return dr.readPacketConn(conn, timeout)
+}
+
 // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
 // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
 // Implementations should never return a nil Reader.
 // Implementations should never return a nil Reader.
+// Readers should also implement the optional PacketConnReader interface.
+// PacketConnReader is required to use a generic net.PacketConn.
 type DecorateReader func(Reader) Reader
 type DecorateReader func(Reader) Reader
 
 
 // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
 // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
@@ -294,6 +321,7 @@ func (srv *Server) ListenAndServe() error {
 		}
 		}
 		u := l.(*net.UDPConn)
 		u := l.(*net.UDPConn)
 		if e := setUDPSocketOptions(u); e != nil {
 		if e := setUDPSocketOptions(u); e != nil {
+			u.Close()
 			return e
 			return e
 		}
 		}
 		srv.PacketConn = l
 		srv.PacketConn = l
@@ -317,24 +345,22 @@ func (srv *Server) ActivateAndServe() error {
 
 
 	srv.init()
 	srv.init()
 
 
-	pConn := srv.PacketConn
-	l := srv.Listener
-	if pConn != nil {
+	if srv.PacketConn != nil {
 		// Check PacketConn interface's type is valid and value
 		// Check PacketConn interface's type is valid and value
 		// is not nil
 		// is not nil
-		if t, ok := pConn.(*net.UDPConn); ok && t != nil {
+		if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
 			if e := setUDPSocketOptions(t); e != nil {
 			if e := setUDPSocketOptions(t); e != nil {
 				return e
 				return e
 			}
 			}
-			srv.started = true
-			unlock()
-			return srv.serveUDP(t)
 		}
 		}
+		srv.started = true
+		unlock()
+		return srv.serveUDP(srv.PacketConn)
 	}
 	}
-	if l != nil {
+	if srv.Listener != nil {
 		srv.started = true
 		srv.started = true
 		unlock()
 		unlock()
-		return srv.serveTCP(l)
+		return srv.serveTCP(srv.Listener)
 	}
 	}
 	return &Error{err: "bad listeners"}
 	return &Error{err: "bad listeners"}
 }
 }
@@ -438,18 +464,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
 }
 }
 
 
 // serveUDP starts a UDP listener for the server.
 // serveUDP starts a UDP listener for the server.
-func (srv *Server) serveUDP(l *net.UDPConn) error {
+func (srv *Server) serveUDP(l net.PacketConn) error {
 	defer l.Close()
 	defer l.Close()
 
 
-	if srv.NotifyStartedFunc != nil {
-		srv.NotifyStartedFunc()
-	}
-
 	reader := Reader(defaultReader{srv})
 	reader := Reader(defaultReader{srv})
 	if srv.DecorateReader != nil {
 	if srv.DecorateReader != nil {
 		reader = srv.DecorateReader(reader)
 		reader = srv.DecorateReader(reader)
 	}
 	}
 
 
+	lUDP, isUDP := l.(*net.UDPConn)
+	readerPC, canPacketConn := reader.(PacketConnReader)
+	if !isUDP && !canPacketConn {
+		return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
+	}
+
+	if srv.NotifyStartedFunc != nil {
+		srv.NotifyStartedFunc()
+	}
+
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	defer func() {
 	defer func() {
 		wg.Wait()
 		wg.Wait()
@@ -459,7 +491,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
 	rtimeout := srv.getReadTimeout()
 	rtimeout := srv.getReadTimeout()
 	// deadline is not used here
 	// deadline is not used here
 	for srv.isStarted() {
 	for srv.isStarted() {
-		m, s, err := reader.ReadUDP(l, rtimeout)
+		var (
+			m    []byte
+			sPC  net.Addr
+			sUDP *SessionUDP
+			err  error
+		)
+		if isUDP {
+			m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
+		} else {
+			m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
+		}
 		if err != nil {
 		if err != nil {
 			if !srv.isStarted() {
 			if !srv.isStarted() {
 				return nil
 				return nil
@@ -476,7 +518,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
 			continue
 			continue
 		}
 		}
 		wg.Add(1)
 		wg.Add(1)
-		go srv.serveUDPPacket(&wg, m, l, s)
+		go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -538,8 +580,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
 }
 }
 
 
 // Serve a new UDP request.
 // Serve a new UDP request.
-func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
-	w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
+func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
+	w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
 	if srv.DecorateWriter != nil {
 	if srv.DecorateWriter != nil {
 		w.writer = srv.DecorateWriter(w)
 		w.writer = srv.DecorateWriter(w)
 	} else {
 	} else {
@@ -651,6 +693,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
 	return m, s, nil
 	return m, s, nil
 }
 }
 
 
+func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+	srv.lock.RLock()
+	if srv.started {
+		// See the comment in readTCP above.
+		conn.SetReadDeadline(time.Now().Add(timeout))
+	}
+	srv.lock.RUnlock()
+
+	m := srv.udpPool.Get().([]byte)
+	n, addr, err := conn.ReadFrom(m)
+	if err != nil {
+		srv.udpPool.Put(m)
+		return nil, nil, err
+	}
+	m = m[:n]
+	return m, addr, nil
+}
+
 // WriteMsg implements the ResponseWriter.WriteMsg method.
 // WriteMsg implements the ResponseWriter.WriteMsg method.
 func (w *response) WriteMsg(m *Msg) (err error) {
 func (w *response) WriteMsg(m *Msg) (err error) {
 	if w.closed {
 	if w.closed {
@@ -684,17 +744,19 @@ func (w *response) Write(m []byte) (int, error) {
 
 
 	switch {
 	switch {
 	case w.udp != nil:
 	case w.udp != nil:
-		return WriteToSessionUDP(w.udp, m, w.udpSession)
+		if u, ok := w.udp.(*net.UDPConn); ok {
+			return WriteToSessionUDP(u, m, w.udpSession)
+		}
+		return w.udp.WriteTo(m, w.pcSession)
 	case w.tcp != nil:
 	case w.tcp != nil:
 		if len(m) > MaxMsgSize {
 		if len(m) > MaxMsgSize {
 			return 0, &Error{err: "message too large"}
 			return 0, &Error{err: "message too large"}
 		}
 		}
 
 
-		l := make([]byte, 2)
-		binary.BigEndian.PutUint16(l, uint16(len(m)))
-
-		n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
-		return int(n), err
+		msg := make([]byte, 2+len(m))
+		binary.BigEndian.PutUint16(msg, uint16(len(m)))
+		copy(msg[2:], m)
+		return w.tcp.Write(msg)
 	default:
 	default:
 		panic("dns: internal error: udp and tcp both nil")
 		panic("dns: internal error: udp and tcp both nil")
 	}
 	}
@@ -717,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
 	switch {
 	switch {
 	case w.udpSession != nil:
 	case w.udpSession != nil:
 		return w.udpSession.RemoteAddr()
 		return w.udpSession.RemoteAddr()
+	case w.pcSession != nil:
+		return w.pcSession
 	case w.tcp != nil:
 	case w.tcp != nil:
 		return w.tcp.RemoteAddr()
 		return w.tcp.RemoteAddr()
 	default:
 	default:
-		panic("dns: internal error: udpSession and tcp both nil")
+		panic("dns: internal error: udpSession, pcSession and tcp are all nil")
 	}
 	}
 }
 }
 
 

+ 3 - 15
vendor/github.com/miekg/dns/sig0.go

@@ -2,7 +2,6 @@ package dns
 
 
 import (
 import (
 	"crypto"
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
 	"crypto/rsa"
 	"crypto/rsa"
 	"encoding/binary"
 	"encoding/binary"
@@ -18,7 +17,7 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
 	if k == nil {
 	if k == nil {
 		return nil, ErrPrivKey
 		return nil, ErrPrivKey
 	}
 	}
-	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
+	if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
 		return nil, ErrKey
 		return nil, ErrKey
 	}
 	}
 
 
@@ -79,13 +78,13 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
 	if k == nil {
 	if k == nil {
 		return ErrKey
 		return ErrKey
 	}
 	}
-	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
+	if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
 		return ErrKey
 		return ErrKey
 	}
 	}
 
 
 	var hash crypto.Hash
 	var hash crypto.Hash
 	switch rr.Algorithm {
 	switch rr.Algorithm {
-	case DSA, RSASHA1:
+	case RSASHA1:
 		hash = crypto.SHA1
 		hash = crypto.SHA1
 	case RSASHA256, ECDSAP256SHA256:
 	case RSASHA256, ECDSAP256SHA256:
 		hash = crypto.SHA256
 		hash = crypto.SHA256
@@ -178,17 +177,6 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
 	hashed := hasher.Sum(nil)
 	hashed := hasher.Sum(nil)
 	sig := buf[sigend:]
 	sig := buf[sigend:]
 	switch k.Algorithm {
 	switch k.Algorithm {
-	case DSA:
-		pk := k.publicKeyDSA()
-		sig = sig[1:]
-		r := new(big.Int).SetBytes(sig[:len(sig)/2])
-		s := new(big.Int).SetBytes(sig[len(sig)/2:])
-		if pk != nil {
-			if dsa.Verify(pk, hashed, r, s) {
-				return nil
-			}
-			return ErrSig
-		}
 	case RSASHA1, RSASHA256, RSASHA512:
 	case RSASHA1, RSASHA256, RSASHA512:
 		pk := k.publicKeyRSA()
 		pk := k.publicKeyRSA()
 		if pk != nil {
 		if pk != nil {

+ 755 - 0
vendor/github.com/miekg/dns/svcb.go

@@ -0,0 +1,755 @@
+package dns
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"net"
+	"sort"
+	"strconv"
+	"strings"
+)
+
+// SVCBKey is the type of the keys used in the SVCB RR.
+type SVCBKey uint16
+
+// Keys defined in draft-ietf-dnsop-svcb-https-01 Section 12.3.2.
+const (
+	SVCB_MANDATORY       SVCBKey = 0
+	SVCB_ALPN            SVCBKey = 1
+	SVCB_NO_DEFAULT_ALPN SVCBKey = 2
+	SVCB_PORT            SVCBKey = 3
+	SVCB_IPV4HINT        SVCBKey = 4
+	SVCB_ECHCONFIG       SVCBKey = 5
+	SVCB_IPV6HINT        SVCBKey = 6
+	svcb_RESERVED        SVCBKey = 65535
+)
+
+var svcbKeyToStringMap = map[SVCBKey]string{
+	SVCB_MANDATORY:       "mandatory",
+	SVCB_ALPN:            "alpn",
+	SVCB_NO_DEFAULT_ALPN: "no-default-alpn",
+	SVCB_PORT:            "port",
+	SVCB_IPV4HINT:        "ipv4hint",
+	SVCB_ECHCONFIG:       "echconfig",
+	SVCB_IPV6HINT:        "ipv6hint",
+}
+
+var svcbStringToKeyMap = reverseSVCBKeyMap(svcbKeyToStringMap)
+
+func reverseSVCBKeyMap(m map[SVCBKey]string) map[string]SVCBKey {
+	n := make(map[string]SVCBKey, len(m))
+	for u, s := range m {
+		n[s] = u
+	}
+	return n
+}
+
+// String takes the numerical code of an SVCB key and returns its name.
+// Returns an empty string for reserved keys.
+// Accepts unassigned keys as well as experimental/private keys.
+func (key SVCBKey) String() string {
+	if x := svcbKeyToStringMap[key]; x != "" {
+		return x
+	}
+	if key == svcb_RESERVED {
+		return ""
+	}
+	return "key" + strconv.FormatUint(uint64(key), 10)
+}
+
+// svcbStringToKey returns the numerical code of an SVCB key.
+// Returns svcb_RESERVED for reserved/invalid keys.
+// Accepts unassigned keys as well as experimental/private keys.
+func svcbStringToKey(s string) SVCBKey {
+	if strings.HasPrefix(s, "key") {
+		a, err := strconv.ParseUint(s[3:], 10, 16)
+		// no leading zeros
+		// key shouldn't be registered
+		if err != nil || a == 65535 || s[3] == '0' || svcbKeyToStringMap[SVCBKey(a)] != "" {
+			return svcb_RESERVED
+		}
+		return SVCBKey(a)
+	}
+	if key, ok := svcbStringToKeyMap[s]; ok {
+		return key
+	}
+	return svcb_RESERVED
+}
+
+func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
+	l, _ := c.Next()
+	i, e := strconv.ParseUint(l.token, 10, 16)
+	if e != nil || l.err {
+		return &ParseError{l.token, "bad SVCB priority", l}
+	}
+	rr.Priority = uint16(i)
+
+	c.Next()        // zBlank
+	l, _ = c.Next() // zString
+	rr.Target = l.token
+
+	name, nameOk := toAbsoluteName(l.token, o)
+	if l.err || !nameOk {
+		return &ParseError{l.token, "bad SVCB Target", l}
+	}
+	rr.Target = name
+
+	// Values (if any)
+	l, _ = c.Next()
+	var xs []SVCBKeyValue
+	// Helps require whitespace between pairs.
+	// Prevents key1000="a"key1001=...
+	canHaveNextKey := true
+	for l.value != zNewline && l.value != zEOF {
+		switch l.value {
+		case zString:
+			if !canHaveNextKey {
+				// The key we can now read was probably meant to be
+				// a part of the last value.
+				return &ParseError{l.token, "bad SVCB value quotation", l}
+			}
+
+			// In key=value pairs, value does not have to be quoted unless value
+			// contains whitespace. And keys don't need to have values.
+			// Similarly, keys with an equality signs after them don't need values.
+			// l.token includes at least up to the first equality sign.
+			idx := strings.IndexByte(l.token, '=')
+			var key, value string
+			if idx < 0 {
+				// Key with no value and no equality sign
+				key = l.token
+			} else if idx == 0 {
+				return &ParseError{l.token, "bad SVCB key", l}
+			} else {
+				key, value = l.token[:idx], l.token[idx+1:]
+
+				if value == "" {
+					// We have a key and an equality sign. Maybe we have nothing
+					// after "=" or we have a double quote.
+					l, _ = c.Next()
+					if l.value == zQuote {
+						// Only needed when value ends with double quotes.
+						// Any value starting with zQuote ends with it.
+						canHaveNextKey = false
+
+						l, _ = c.Next()
+						switch l.value {
+						case zString:
+							// We have a value in double quotes.
+							value = l.token
+							l, _ = c.Next()
+							if l.value != zQuote {
+								return &ParseError{l.token, "SVCB unterminated value", l}
+							}
+						case zQuote:
+							// There's nothing in double quotes.
+						default:
+							return &ParseError{l.token, "bad SVCB value", l}
+						}
+					}
+				}
+			}
+			kv := makeSVCBKeyValue(svcbStringToKey(key))
+			if kv == nil {
+				return &ParseError{l.token, "bad SVCB key", l}
+			}
+			if err := kv.parse(value); err != nil {
+				return &ParseError{l.token, err.Error(), l}
+			}
+			xs = append(xs, kv)
+		case zQuote:
+			return &ParseError{l.token, "SVCB key can't contain double quotes", l}
+		case zBlank:
+			canHaveNextKey = true
+		default:
+			return &ParseError{l.token, "bad SVCB values", l}
+		}
+		l, _ = c.Next()
+	}
+	rr.Value = xs
+	if rr.Priority == 0 && len(xs) > 0 {
+		return &ParseError{l.token, "SVCB aliasform can't have values", l}
+	}
+	return nil
+}
+
+// makeSVCBKeyValue returns an SVCBKeyValue struct with the key or nil for reserved keys.
+func makeSVCBKeyValue(key SVCBKey) SVCBKeyValue {
+	switch key {
+	case SVCB_MANDATORY:
+		return new(SVCBMandatory)
+	case SVCB_ALPN:
+		return new(SVCBAlpn)
+	case SVCB_NO_DEFAULT_ALPN:
+		return new(SVCBNoDefaultAlpn)
+	case SVCB_PORT:
+		return new(SVCBPort)
+	case SVCB_IPV4HINT:
+		return new(SVCBIPv4Hint)
+	case SVCB_ECHCONFIG:
+		return new(SVCBECHConfig)
+	case SVCB_IPV6HINT:
+		return new(SVCBIPv6Hint)
+	case svcb_RESERVED:
+		return nil
+	default:
+		e := new(SVCBLocal)
+		e.KeyCode = key
+		return e
+	}
+}
+
+// SVCB RR. See RFC xxxx (https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-01).
+type SVCB struct {
+	Hdr      RR_Header
+	Priority uint16
+	Target   string         `dns:"domain-name"`
+	Value    []SVCBKeyValue `dns:"pairs"` // Value must be empty if Priority is zero.
+}
+
+// HTTPS RR. Everything valid for SVCB applies to HTTPS as well.
+// Except that the HTTPS record is intended for use with the HTTP and HTTPS protocols.
+type HTTPS struct {
+	SVCB
+}
+
+func (rr *HTTPS) String() string {
+	return rr.SVCB.String()
+}
+
+func (rr *HTTPS) parse(c *zlexer, o string) *ParseError {
+	return rr.SVCB.parse(c, o)
+}
+
+// SVCBKeyValue defines a key=value pair for the SVCB RR type.
+// An SVCB RR can have multiple SVCBKeyValues appended to it.
+type SVCBKeyValue interface {
+	Key() SVCBKey          // Key returns the numerical key code.
+	pack() ([]byte, error) // pack returns the encoded value.
+	unpack([]byte) error   // unpack sets the value.
+	String() string        // String returns the string representation of the value.
+	parse(string) error    // parse sets the value to the given string representation of the value.
+	copy() SVCBKeyValue    // copy returns a deep-copy of the pair.
+	len() int              // len returns the length of value in the wire format.
+}
+
+// SVCBMandatory pair adds to required keys that must be interpreted for the RR
+// to be functional.
+// Basic use pattern for creating a mandatory option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBMandatory)
+//	e.Code = []uint16{65403}
+//	s.Value = append(s.Value, e)
+type SVCBMandatory struct {
+	Code []SVCBKey // Must not include mandatory
+}
+
+func (*SVCBMandatory) Key() SVCBKey { return SVCB_MANDATORY }
+
+func (s *SVCBMandatory) String() string {
+	str := make([]string, len(s.Code))
+	for i, e := range s.Code {
+		str[i] = e.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBMandatory) pack() ([]byte, error) {
+	codes := append([]SVCBKey(nil), s.Code...)
+	sort.Slice(codes, func(i, j int) bool {
+		return codes[i] < codes[j]
+	})
+	b := make([]byte, 2*len(codes))
+	for i, e := range codes {
+		binary.BigEndian.PutUint16(b[2*i:], uint16(e))
+	}
+	return b, nil
+}
+
+func (s *SVCBMandatory) unpack(b []byte) error {
+	if len(b)%2 != 0 {
+		return errors.New("dns: svcbmandatory: value length is not a multiple of 2")
+	}
+	codes := make([]SVCBKey, 0, len(b)/2)
+	for i := 0; i < len(b); i += 2 {
+		// We assume strictly increasing order.
+		codes = append(codes, SVCBKey(binary.BigEndian.Uint16(b[i:])))
+	}
+	s.Code = codes
+	return nil
+}
+
+func (s *SVCBMandatory) parse(b string) error {
+	str := strings.Split(b, ",")
+	codes := make([]SVCBKey, 0, len(str))
+	for _, e := range str {
+		codes = append(codes, svcbStringToKey(e))
+	}
+	s.Code = codes
+	return nil
+}
+
+func (s *SVCBMandatory) len() int {
+	return 2 * len(s.Code)
+}
+
+func (s *SVCBMandatory) copy() SVCBKeyValue {
+	return &SVCBMandatory{
+		append([]SVCBKey(nil), s.Code...),
+	}
+}
+
+// SVCBAlpn pair is used to list supported connection protocols.
+// Protocol ids can be found at:
+// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
+// Basic use pattern for creating an alpn option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBAlpn)
+//	e.Alpn = []string{"h2", "http/1.1"}
+//	h.Value = append(o.Value, e)
+type SVCBAlpn struct {
+	Alpn []string
+}
+
+func (*SVCBAlpn) Key() SVCBKey     { return SVCB_ALPN }
+func (s *SVCBAlpn) String() string { return strings.Join(s.Alpn, ",") }
+
+func (s *SVCBAlpn) pack() ([]byte, error) {
+	// Liberally estimate the size of an alpn as 10 octets
+	b := make([]byte, 0, 10*len(s.Alpn))
+	for _, e := range s.Alpn {
+		if e == "" {
+			return nil, errors.New("dns: svcbalpn: empty alpn-id")
+		}
+		if len(e) > 255 {
+			return nil, errors.New("dns: svcbalpn: alpn-id too long")
+		}
+		b = append(b, byte(len(e)))
+		b = append(b, e...)
+	}
+	return b, nil
+}
+
+func (s *SVCBAlpn) unpack(b []byte) error {
+	// Estimate the size of the smallest alpn as 4 bytes
+	alpn := make([]string, 0, len(b)/4)
+	for i := 0; i < len(b); {
+		length := int(b[i])
+		i++
+		if i+length > len(b) {
+			return errors.New("dns: svcbalpn: alpn array overflowing")
+		}
+		alpn = append(alpn, string(b[i:i+length]))
+		i += length
+	}
+	s.Alpn = alpn
+	return nil
+}
+
+func (s *SVCBAlpn) parse(b string) error {
+	s.Alpn = strings.Split(b, ",")
+	return nil
+}
+
+func (s *SVCBAlpn) len() int {
+	var l int
+	for _, e := range s.Alpn {
+		l += 1 + len(e)
+	}
+	return l
+}
+
+func (s *SVCBAlpn) copy() SVCBKeyValue {
+	return &SVCBAlpn{
+		append([]string(nil), s.Alpn...),
+	}
+}
+
+// SVCBNoDefaultAlpn pair signifies no support for default connection protocols.
+// Basic use pattern for creating a no-default-alpn option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBNoDefaultAlpn)
+//	s.Value = append(s.Value, e)
+type SVCBNoDefaultAlpn struct{}
+
+func (*SVCBNoDefaultAlpn) Key() SVCBKey          { return SVCB_NO_DEFAULT_ALPN }
+func (*SVCBNoDefaultAlpn) copy() SVCBKeyValue    { return &SVCBNoDefaultAlpn{} }
+func (*SVCBNoDefaultAlpn) pack() ([]byte, error) { return []byte{}, nil }
+func (*SVCBNoDefaultAlpn) String() string        { return "" }
+func (*SVCBNoDefaultAlpn) len() int              { return 0 }
+
+func (*SVCBNoDefaultAlpn) unpack(b []byte) error {
+	if len(b) != 0 {
+		return errors.New("dns: svcbnodefaultalpn: no_default_alpn must have no value")
+	}
+	return nil
+}
+
+func (*SVCBNoDefaultAlpn) parse(b string) error {
+	if b != "" {
+		return errors.New("dns: svcbnodefaultalpn: no_default_alpn must have no value")
+	}
+	return nil
+}
+
+// SVCBPort pair defines the port for connection.
+// Basic use pattern for creating a port option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBPort)
+//	e.Port = 80
+//	s.Value = append(s.Value, e)
+type SVCBPort struct {
+	Port uint16
+}
+
+func (*SVCBPort) Key() SVCBKey         { return SVCB_PORT }
+func (*SVCBPort) len() int             { return 2 }
+func (s *SVCBPort) String() string     { return strconv.FormatUint(uint64(s.Port), 10) }
+func (s *SVCBPort) copy() SVCBKeyValue { return &SVCBPort{s.Port} }
+
+func (s *SVCBPort) unpack(b []byte) error {
+	if len(b) != 2 {
+		return errors.New("dns: svcbport: port length is not exactly 2 octets")
+	}
+	s.Port = binary.BigEndian.Uint16(b)
+	return nil
+}
+
+func (s *SVCBPort) pack() ([]byte, error) {
+	b := make([]byte, 2)
+	binary.BigEndian.PutUint16(b, s.Port)
+	return b, nil
+}
+
+func (s *SVCBPort) parse(b string) error {
+	port, err := strconv.ParseUint(b, 10, 16)
+	if err != nil {
+		return errors.New("dns: svcbport: port out of range")
+	}
+	s.Port = uint16(port)
+	return nil
+}
+
+// SVCBIPv4Hint pair suggests an IPv4 address which may be used to open connections
+// if A and AAAA record responses for SVCB's Target domain haven't been received.
+// In that case, optionally, A and AAAA requests can be made, after which the connection
+// to the hinted IP address may be terminated and a new connection may be opened.
+// Basic use pattern for creating an ipv4hint option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBIPv4Hint)
+//	e.Hint = []net.IP{net.IPv4(1,1,1,1).To4()}
+//
+//  Or
+//
+//	e.Hint = []net.IP{net.ParseIP("1.1.1.1").To4()}
+//	h.Value = append(h.Value, e)
+type SVCBIPv4Hint struct {
+	Hint []net.IP
+}
+
+func (*SVCBIPv4Hint) Key() SVCBKey { return SVCB_IPV4HINT }
+func (s *SVCBIPv4Hint) len() int   { return 4 * len(s.Hint) }
+
+func (s *SVCBIPv4Hint) pack() ([]byte, error) {
+	b := make([]byte, 0, 4*len(s.Hint))
+	for _, e := range s.Hint {
+		x := e.To4()
+		if x == nil {
+			return nil, errors.New("dns: svcbipv4hint: expected ipv4, hint is ipv6")
+		}
+		b = append(b, x...)
+	}
+	return b, nil
+}
+
+func (s *SVCBIPv4Hint) unpack(b []byte) error {
+	if len(b) == 0 || len(b)%4 != 0 {
+		return errors.New("dns: svcbipv4hint: ipv4 address byte array length is not a multiple of 4")
+	}
+	x := make([]net.IP, 0, len(b)/4)
+	for i := 0; i < len(b); i += 4 {
+		x = append(x, net.IP(b[i:i+4]))
+	}
+	s.Hint = x
+	return nil
+}
+
+func (s *SVCBIPv4Hint) String() string {
+	str := make([]string, len(s.Hint))
+	for i, e := range s.Hint {
+		x := e.To4()
+		if x == nil {
+			return "<nil>"
+		}
+		str[i] = x.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBIPv4Hint) parse(b string) error {
+	if strings.Contains(b, ":") {
+		return errors.New("dns: svcbipv4hint: expected ipv4, got ipv6")
+	}
+	str := strings.Split(b, ",")
+	dst := make([]net.IP, len(str))
+	for i, e := range str {
+		ip := net.ParseIP(e).To4()
+		if ip == nil {
+			return errors.New("dns: svcbipv4hint: bad ip")
+		}
+		dst[i] = ip
+	}
+	s.Hint = dst
+	return nil
+}
+
+func (s *SVCBIPv4Hint) copy() SVCBKeyValue {
+	hint := make([]net.IP, len(s.Hint))
+	for i, ip := range s.Hint {
+		hint[i] = copyIP(ip)
+	}
+
+	return &SVCBIPv4Hint{
+		Hint: hint,
+	}
+}
+
+// SVCBECHConfig pair contains the ECHConfig structure defined in draft-ietf-tls-esni [RFC xxxx].
+// Basic use pattern for creating an echconfig option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBECHConfig)
+//	e.ECH = []byte{0xfe, 0x08, ...}
+//	h.Value = append(h.Value, e)
+type SVCBECHConfig struct {
+	ECH []byte
+}
+
+func (*SVCBECHConfig) Key() SVCBKey     { return SVCB_ECHCONFIG }
+func (s *SVCBECHConfig) String() string { return toBase64(s.ECH) }
+func (s *SVCBECHConfig) len() int       { return len(s.ECH) }
+
+func (s *SVCBECHConfig) pack() ([]byte, error) {
+	return append([]byte(nil), s.ECH...), nil
+}
+
+func (s *SVCBECHConfig) copy() SVCBKeyValue {
+	return &SVCBECHConfig{
+		append([]byte(nil), s.ECH...),
+	}
+}
+
+func (s *SVCBECHConfig) unpack(b []byte) error {
+	s.ECH = append([]byte(nil), b...)
+	return nil
+}
+func (s *SVCBECHConfig) parse(b string) error {
+	x, err := fromBase64([]byte(b))
+	if err != nil {
+		return errors.New("dns: svcbechconfig: bad base64 echconfig")
+	}
+	s.ECH = x
+	return nil
+}
+
+// SVCBIPv6Hint pair suggests an IPv6 address which may be used to open connections
+// if A and AAAA record responses for SVCB's Target domain haven't been received.
+// In that case, optionally, A and AAAA requests can be made, after which the
+// connection to the hinted IP address may be terminated and a new connection may be opened.
+// Basic use pattern for creating an ipv6hint option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBIPv6Hint)
+//	e.Hint = []net.IP{net.ParseIP("2001:db8::1")}
+//	h.Value = append(h.Value, e)
+type SVCBIPv6Hint struct {
+	Hint []net.IP
+}
+
+func (*SVCBIPv6Hint) Key() SVCBKey { return SVCB_IPV6HINT }
+func (s *SVCBIPv6Hint) len() int   { return 16 * len(s.Hint) }
+
+func (s *SVCBIPv6Hint) pack() ([]byte, error) {
+	b := make([]byte, 0, 16*len(s.Hint))
+	for _, e := range s.Hint {
+		if len(e) != net.IPv6len || e.To4() != nil {
+			return nil, errors.New("dns: svcbipv6hint: expected ipv6, hint is ipv4")
+		}
+		b = append(b, e...)
+	}
+	return b, nil
+}
+
+func (s *SVCBIPv6Hint) unpack(b []byte) error {
+	if len(b) == 0 || len(b)%16 != 0 {
+		return errors.New("dns: svcbipv6hint: ipv6 address byte array length not a multiple of 16")
+	}
+	x := make([]net.IP, 0, len(b)/16)
+	for i := 0; i < len(b); i += 16 {
+		ip := net.IP(b[i : i+16])
+		if ip.To4() != nil {
+			return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4")
+		}
+		x = append(x, ip)
+	}
+	s.Hint = x
+	return nil
+}
+
+func (s *SVCBIPv6Hint) String() string {
+	str := make([]string, len(s.Hint))
+	for i, e := range s.Hint {
+		if x := e.To4(); x != nil {
+			return "<nil>"
+		}
+		str[i] = e.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBIPv6Hint) parse(b string) error {
+	if strings.Contains(b, ".") {
+		return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4")
+	}
+	str := strings.Split(b, ",")
+	dst := make([]net.IP, len(str))
+	for i, e := range str {
+		ip := net.ParseIP(e)
+		if ip == nil {
+			return errors.New("dns: svcbipv6hint: bad ip")
+		}
+		dst[i] = ip
+	}
+	s.Hint = dst
+	return nil
+}
+
+func (s *SVCBIPv6Hint) copy() SVCBKeyValue {
+	hint := make([]net.IP, len(s.Hint))
+	for i, ip := range s.Hint {
+		hint[i] = copyIP(ip)
+	}
+
+	return &SVCBIPv6Hint{
+		Hint: hint,
+	}
+}
+
+// SVCBLocal pair is intended for experimental/private use. The key is recommended
+// to be in the range [SVCB_PRIVATE_LOWER, SVCB_PRIVATE_UPPER].
+// Basic use pattern for creating a keyNNNNN option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBLocal)
+//	e.KeyCode = 65400
+//	e.Data = []byte("abc")
+//	h.Value = append(h.Value, e)
+type SVCBLocal struct {
+	KeyCode SVCBKey // Never 65535 or any assigned keys.
+	Data    []byte  // All byte sequences are allowed.
+}
+
+func (s *SVCBLocal) Key() SVCBKey          { return s.KeyCode }
+func (s *SVCBLocal) pack() ([]byte, error) { return append([]byte(nil), s.Data...), nil }
+func (s *SVCBLocal) len() int              { return len(s.Data) }
+
+func (s *SVCBLocal) unpack(b []byte) error {
+	s.Data = append([]byte(nil), b...)
+	return nil
+}
+
+func (s *SVCBLocal) String() string {
+	var str strings.Builder
+	str.Grow(4 * len(s.Data))
+	for _, e := range s.Data {
+		if ' ' <= e && e <= '~' {
+			switch e {
+			case '"', ';', ' ', '\\':
+				str.WriteByte('\\')
+				str.WriteByte(e)
+			default:
+				str.WriteByte(e)
+			}
+		} else {
+			str.WriteString(escapeByte(e))
+		}
+	}
+	return str.String()
+}
+
+func (s *SVCBLocal) parse(b string) error {
+	data := make([]byte, 0, len(b))
+	for i := 0; i < len(b); {
+		if b[i] != '\\' {
+			data = append(data, b[i])
+			i++
+			continue
+		}
+		if i+1 == len(b) {
+			return errors.New("dns: svcblocal: svcb private/experimental key escape unterminated")
+		}
+		if isDigit(b[i+1]) {
+			if i+3 < len(b) && isDigit(b[i+2]) && isDigit(b[i+3]) {
+				a, err := strconv.ParseUint(b[i+1:i+4], 10, 8)
+				if err == nil {
+					i += 4
+					data = append(data, byte(a))
+					continue
+				}
+			}
+			return errors.New("dns: svcblocal: svcb private/experimental key bad escaped octet")
+		} else {
+			data = append(data, b[i+1])
+			i += 2
+		}
+	}
+	s.Data = data
+	return nil
+}
+
+func (s *SVCBLocal) copy() SVCBKeyValue {
+	return &SVCBLocal{s.KeyCode,
+		append([]byte(nil), s.Data...),
+	}
+}
+
+func (rr *SVCB) String() string {
+	s := rr.Hdr.String() +
+		strconv.Itoa(int(rr.Priority)) + " " +
+		sprintName(rr.Target)
+	for _, e := range rr.Value {
+		s += " " + e.Key().String() + "=\"" + e.String() + "\""
+	}
+	return s
+}
+
+// areSVCBPairArraysEqual checks if SVCBKeyValue arrays are equal after sorting their
+// copies. arrA and arrB have equal lengths, otherwise zduplicate.go wouldn't call this function.
+func areSVCBPairArraysEqual(a []SVCBKeyValue, b []SVCBKeyValue) bool {
+	a = append([]SVCBKeyValue(nil), a...)
+	b = append([]SVCBKeyValue(nil), b...)
+	sort.Slice(a, func(i, j int) bool { return a[i].Key() < a[j].Key() })
+	sort.Slice(b, func(i, j int) bool { return b[i].Key() < b[j].Key() })
+	for i, e := range a {
+		if e.Key() != b[i].Key() {
+			return false
+		}
+		b1, err1 := e.pack()
+		b2, err2 := b[i].pack()
+		if err1 != nil || err2 != nil || !bytes.Equal(b1, b2) {
+			return false
+		}
+	}
+	return true
+}

+ 99 - 59
vendor/github.com/miekg/dns/tsig.go

@@ -2,7 +2,6 @@ package dns
 
 
 import (
 import (
 	"crypto/hmac"
 	"crypto/hmac"
-	"crypto/md5"
 	"crypto/sha1"
 	"crypto/sha1"
 	"crypto/sha256"
 	"crypto/sha256"
 	"crypto/sha512"
 	"crypto/sha512"
@@ -16,12 +15,65 @@ import (
 
 
 // HMAC hashing codes. These are transmitted as domain names.
 // HMAC hashing codes. These are transmitted as domain names.
 const (
 const (
-	HmacMD5    = "hmac-md5.sig-alg.reg.int."
 	HmacSHA1   = "hmac-sha1."
 	HmacSHA1   = "hmac-sha1."
+	HmacSHA224 = "hmac-sha224."
 	HmacSHA256 = "hmac-sha256."
 	HmacSHA256 = "hmac-sha256."
+	HmacSHA384 = "hmac-sha384."
 	HmacSHA512 = "hmac-sha512."
 	HmacSHA512 = "hmac-sha512."
+
+	HmacMD5 = "hmac-md5.sig-alg.reg.int." // Deprecated: HmacMD5 is no longer supported.
 )
 )
 
 
+// TsigProvider provides the API to plug-in a custom TSIG implementation.
+type TsigProvider interface {
+	// Generate is passed the DNS message to be signed and the partial TSIG RR. It returns the signature and nil, otherwise an error.
+	Generate(msg []byte, t *TSIG) ([]byte, error)
+	// Verify is passed the DNS message to be verified and the TSIG RR. If the signature is valid it will return nil, otherwise an error.
+	Verify(msg []byte, t *TSIG) error
+}
+
+type tsigHMACProvider string
+
+func (key tsigHMACProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
+	// If we barf here, the caller is to blame
+	rawsecret, err := fromBase64([]byte(key))
+	if err != nil {
+		return nil, err
+	}
+	var h hash.Hash
+	switch CanonicalName(t.Algorithm) {
+	case HmacSHA1:
+		h = hmac.New(sha1.New, rawsecret)
+	case HmacSHA224:
+		h = hmac.New(sha256.New224, rawsecret)
+	case HmacSHA256:
+		h = hmac.New(sha256.New, rawsecret)
+	case HmacSHA384:
+		h = hmac.New(sha512.New384, rawsecret)
+	case HmacSHA512:
+		h = hmac.New(sha512.New, rawsecret)
+	default:
+		return nil, ErrKeyAlg
+	}
+	h.Write(msg)
+	return h.Sum(nil), nil
+}
+
+func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
+	b, err := key.Generate(msg, t)
+	if err != nil {
+		return err
+	}
+	mac, err := hex.DecodeString(t.MAC)
+	if err != nil {
+		return err
+	}
+	if !hmac.Equal(b, mac) {
+		return ErrSig
+	}
+	return nil
+}
+
 // TSIG is the RR the holds the transaction signature of a message.
 // TSIG is the RR the holds the transaction signature of a message.
 // See RFC 2845 and RFC 4635.
 // See RFC 2845 and RFC 4635.
 type TSIG struct {
 type TSIG struct {
@@ -54,8 +106,8 @@ func (rr *TSIG) String() string {
 	return s
 	return s
 }
 }
 
 
-func (rr *TSIG) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on TSIG")
+func (*TSIG) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "TSIG records do not have a presentation format"}
 }
 }
 
 
 // The following values must be put in wireformat, so that the MAC can be calculated.
 // The following values must be put in wireformat, so that the MAC can be calculated.
@@ -96,14 +148,13 @@ type timerWireFmt struct {
 // timersOnly is false.
 // timersOnly is false.
 // If something goes wrong an error is returned, otherwise it is nil.
 // If something goes wrong an error is returned, otherwise it is nil.
 func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) {
 func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) {
+	return tsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly)
+}
+
+func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) {
 	if m.IsTsig() == nil {
 	if m.IsTsig() == nil {
 		panic("dns: TSIG not last RR in additional")
 		panic("dns: TSIG not last RR in additional")
 	}
 	}
-	// If we barf here, the caller is to blame
-	rawsecret, err := fromBase64([]byte(secret))
-	if err != nil {
-		return nil, "", err
-	}
 
 
 	rr := m.Extra[len(m.Extra)-1].(*TSIG)
 	rr := m.Extra[len(m.Extra)-1].(*TSIG)
 	m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
 	m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
@@ -111,32 +162,21 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
 	if err != nil {
 	if err != nil {
 		return nil, "", err
 		return nil, "", err
 	}
 	}
-	buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
+	buf, err := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
+	if err != nil {
+		return nil, "", err
+	}
 
 
 	t := new(TSIG)
 	t := new(TSIG)
-	var h hash.Hash
-	switch CanonicalName(rr.Algorithm) {
-	case HmacMD5:
-		h = hmac.New(md5.New, rawsecret)
-	case HmacSHA1:
-		h = hmac.New(sha1.New, rawsecret)
-	case HmacSHA256:
-		h = hmac.New(sha256.New, rawsecret)
-	case HmacSHA512:
-		h = hmac.New(sha512.New, rawsecret)
-	default:
-		return nil, "", ErrKeyAlg
+	// Copy all TSIG fields except MAC and its size, which are filled using the computed digest.
+	*t = *rr
+	mac, err := provider.Generate(buf, rr)
+	if err != nil {
+		return nil, "", err
 	}
 	}
-	h.Write(buf)
-	t.MAC = hex.EncodeToString(h.Sum(nil))
+	t.MAC = hex.EncodeToString(mac)
 	t.MACSize = uint16(len(t.MAC) / 2) // Size is half!
 	t.MACSize = uint16(len(t.MAC) / 2) // Size is half!
 
 
-	t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0}
-	t.Fudge = rr.Fudge
-	t.TimeSigned = rr.TimeSigned
-	t.Algorithm = rr.Algorithm
-	t.OrigId = m.Id
-
 	tbuf := make([]byte, Len(t))
 	tbuf := make([]byte, Len(t))
 	off, err := PackRR(t, tbuf, 0, nil, false)
 	off, err := PackRR(t, tbuf, 0, nil, false)
 	if err != nil {
 	if err != nil {
@@ -153,26 +193,34 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
 // If the signature does not validate err contains the
 // If the signature does not validate err contains the
 // error, otherwise it is nil.
 // error, otherwise it is nil.
 func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
 func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
-	rawsecret, err := fromBase64([]byte(secret))
-	if err != nil {
-		return err
-	}
+	return tsigVerify(msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix()))
+}
+
+func tsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error {
+	return tsigVerify(msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix()))
+}
+
+// actual implementation of TsigVerify, taking the current time ('now') as a parameter for the convenience of tests.
+func tsigVerify(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool, now uint64) error {
 	// Strip the TSIG from the incoming msg
 	// Strip the TSIG from the incoming msg
 	stripped, tsig, err := stripTsig(msg)
 	stripped, tsig, err := stripTsig(msg)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	msgMAC, err := hex.DecodeString(tsig.MAC)
+	buf, err := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
+	if err := provider.Verify(buf, tsig); err != nil {
+		return err
+	}
 
 
 	// Fudge factor works both ways. A message can arrive before it was signed because
 	// Fudge factor works both ways. A message can arrive before it was signed because
 	// of clock skew.
 	// of clock skew.
-	now := uint64(time.Now().Unix())
+	// We check this after verifying the signature, following draft-ietf-dnsop-rfc2845bis
+	// instead of RFC2845, in order to prevent a security vulnerability as reported in CVE-2017-3142/3143.
 	ti := now - tsig.TimeSigned
 	ti := now - tsig.TimeSigned
 	if now < tsig.TimeSigned {
 	if now < tsig.TimeSigned {
 		ti = tsig.TimeSigned - now
 		ti = tsig.TimeSigned - now
@@ -181,28 +229,11 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
 		return ErrTime
 		return ErrTime
 	}
 	}
 
 
-	var h hash.Hash
-	switch CanonicalName(tsig.Algorithm) {
-	case HmacMD5:
-		h = hmac.New(md5.New, rawsecret)
-	case HmacSHA1:
-		h = hmac.New(sha1.New, rawsecret)
-	case HmacSHA256:
-		h = hmac.New(sha256.New, rawsecret)
-	case HmacSHA512:
-		h = hmac.New(sha512.New, rawsecret)
-	default:
-		return ErrKeyAlg
-	}
-	h.Write(buf)
-	if !hmac.Equal(h.Sum(nil), msgMAC) {
-		return ErrSig
-	}
 	return nil
 	return nil
 }
 }
 
 
 // Create a wiredata buffer for the MAC calculation.
 // Create a wiredata buffer for the MAC calculation.
-func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte {
+func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) ([]byte, error) {
 	var buf []byte
 	var buf []byte
 	if rr.TimeSigned == 0 {
 	if rr.TimeSigned == 0 {
 		rr.TimeSigned = uint64(time.Now().Unix())
 		rr.TimeSigned = uint64(time.Now().Unix())
@@ -219,7 +250,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		m.MACSize = uint16(len(requestMAC) / 2)
 		m.MACSize = uint16(len(requestMAC) / 2)
 		m.MAC = requestMAC
 		m.MAC = requestMAC
 		buf = make([]byte, len(requestMAC)) // long enough
 		buf = make([]byte, len(requestMAC)) // long enough
-		n, _ := packMacWire(m, buf)
+		n, err := packMacWire(m, buf)
+		if err != nil {
+			return nil, err
+		}
 		buf = buf[:n]
 		buf = buf[:n]
 	}
 	}
 
 
@@ -228,7 +262,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		tsig := new(timerWireFmt)
 		tsig := new(timerWireFmt)
 		tsig.TimeSigned = rr.TimeSigned
 		tsig.TimeSigned = rr.TimeSigned
 		tsig.Fudge = rr.Fudge
 		tsig.Fudge = rr.Fudge
-		n, _ := packTimerWire(tsig, tsigvar)
+		n, err := packTimerWire(tsig, tsigvar)
+		if err != nil {
+			return nil, err
+		}
 		tsigvar = tsigvar[:n]
 		tsigvar = tsigvar[:n]
 	} else {
 	} else {
 		tsig := new(tsigWireFmt)
 		tsig := new(tsigWireFmt)
@@ -241,7 +278,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		tsig.Error = rr.Error
 		tsig.Error = rr.Error
 		tsig.OtherLen = rr.OtherLen
 		tsig.OtherLen = rr.OtherLen
 		tsig.OtherData = rr.OtherData
 		tsig.OtherData = rr.OtherData
-		n, _ := packTsigWire(tsig, tsigvar)
+		n, err := packTsigWire(tsig, tsigvar)
+		if err != nil {
+			return nil, err
+		}
 		tsigvar = tsigvar[:n]
 		tsigvar = tsigvar[:n]
 	}
 	}
 
 
@@ -251,7 +291,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 	} else {
 	} else {
 		buf = append(msgbuf, tsigvar...)
 		buf = append(msgbuf, tsigvar...)
 	}
 	}
-	return buf
+	return buf, nil
 }
 }
 
 
 // Strip the TSIG from the raw message.
 // Strip the TSIG from the raw message.

+ 79 - 50
vendor/github.com/miekg/dns/types.go

@@ -81,6 +81,9 @@ const (
 	TypeCDNSKEY    uint16 = 60
 	TypeCDNSKEY    uint16 = 60
 	TypeOPENPGPKEY uint16 = 61
 	TypeOPENPGPKEY uint16 = 61
 	TypeCSYNC      uint16 = 62
 	TypeCSYNC      uint16 = 62
+	TypeZONEMD     uint16 = 63
+	TypeSVCB       uint16 = 64
+	TypeHTTPS      uint16 = 65
 	TypeSPF        uint16 = 99
 	TypeSPF        uint16 = 99
 	TypeUINFO      uint16 = 100
 	TypeUINFO      uint16 = 100
 	TypeUID        uint16 = 101
 	TypeUID        uint16 = 101
@@ -148,6 +151,14 @@ const (
 	OpcodeUpdate = 5
 	OpcodeUpdate = 5
 )
 )
 
 
+// Used in ZONEMD https://tools.ietf.org/html/rfc8976
+const (
+	ZoneMDSchemeSimple = 1
+
+	ZoneMDHashAlgSHA384 = 1
+	ZoneMDHashAlgSHA512 = 2
+)
+
 // Header is the wire format for the DNS packet header.
 // Header is the wire format for the DNS packet header.
 type Header struct {
 type Header struct {
 	Id                                 uint16
 	Id                                 uint16
@@ -243,8 +254,8 @@ type ANY struct {
 
 
 func (rr *ANY) String() string { return rr.Hdr.String() }
 func (rr *ANY) String() string { return rr.Hdr.String() }
 
 
-func (rr *ANY) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on ANY")
+func (*ANY) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "ANY records do not have a presentation format"}
 }
 }
 
 
 // NULL RR. See RFC 1035.
 // NULL RR. See RFC 1035.
@@ -258,8 +269,8 @@ func (rr *NULL) String() string {
 	return ";" + rr.Hdr.String() + rr.Data
 	return ";" + rr.Hdr.String() + rr.Data
 }
 }
 
 
-func (rr *NULL) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on NULL")
+func (*NULL) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "NULL records do not have a presentation format"}
 }
 }
 
 
 // CNAME RR. See RFC 1034.
 // CNAME RR. See RFC 1034.
@@ -445,45 +456,38 @@ func sprintName(s string) string {
 	var dst strings.Builder
 	var dst strings.Builder
 
 
 	for i := 0; i < len(s); {
 	for i := 0; i < len(s); {
-		if i+1 < len(s) && s[i] == '\\' && s[i+1] == '.' {
+		if s[i] == '.' {
 			if dst.Len() != 0 {
 			if dst.Len() != 0 {
-				dst.WriteString(s[i : i+2])
+				dst.WriteByte('.')
 			}
 			}
-			i += 2
+			i++
 			continue
 			continue
 		}
 		}
 
 
 		b, n := nextByte(s, i)
 		b, n := nextByte(s, i)
 		if n == 0 {
 		if n == 0 {
-			i++
-			continue
-		}
-		if b == '.' {
-			if dst.Len() != 0 {
-				dst.WriteByte('.')
+			// Drop "dangling" incomplete escapes.
+			if dst.Len() == 0 {
+				return s[:i]
 			}
 			}
-			i += n
-			continue
+			break
 		}
 		}
-		switch b {
-		case ' ', '\'', '@', ';', '(', ')', '"', '\\': // additional chars to escape
+		if isDomainNameLabelSpecial(b) {
 			if dst.Len() == 0 {
 			if dst.Len() == 0 {
 				dst.Grow(len(s) * 2)
 				dst.Grow(len(s) * 2)
 				dst.WriteString(s[:i])
 				dst.WriteString(s[:i])
 			}
 			}
 			dst.WriteByte('\\')
 			dst.WriteByte('\\')
 			dst.WriteByte(b)
 			dst.WriteByte(b)
-		default:
-			if ' ' <= b && b <= '~' {
-				if dst.Len() != 0 {
-					dst.WriteByte(b)
-				}
-			} else {
-				if dst.Len() == 0 {
-					dst.Grow(len(s) * 2)
-					dst.WriteString(s[:i])
-				}
-				dst.WriteString(escapeByte(b))
+		} else if b < ' ' || b > '~' { // unprintable, use \DDD
+			if dst.Len() == 0 {
+				dst.Grow(len(s) * 2)
+				dst.WriteString(s[:i])
+			}
+			dst.WriteString(escapeByte(b))
+		} else {
+			if dst.Len() != 0 {
+				dst.WriteByte(b)
 			}
 			}
 		}
 		}
 		i += n
 		i += n
@@ -506,15 +510,10 @@ func sprintTxtOctet(s string) string {
 		}
 		}
 
 
 		b, n := nextByte(s, i)
 		b, n := nextByte(s, i)
-		switch {
-		case n == 0:
+		if n == 0 {
 			i++ // dangling back slash
 			i++ // dangling back slash
-		case b == '.':
-			dst.WriteByte('.')
-		case b < ' ' || b > '~':
-			dst.WriteString(escapeByte(b))
-		default:
-			dst.WriteByte(b)
+		} else {
+			writeTXTStringByte(&dst, b)
 		}
 		}
 		i += n
 		i += n
 	}
 	}
@@ -590,6 +589,17 @@ func escapeByte(b byte) string {
 	return escapedByteLarge[int(b)*4 : int(b)*4+4]
 	return escapedByteLarge[int(b)*4 : int(b)*4+4]
 }
 }
 
 
+// isDomainNameLabelSpecial returns true if
+// a domain name label byte should be prefixed
+// with an escaping backslash.
+func isDomainNameLabelSpecial(b byte) bool {
+	switch b {
+	case '.', ' ', '\'', '@', ';', '(', ')', '"', '\\':
+		return true
+	}
+	return false
+}
+
 func nextByte(s string, offset int) (byte, int) {
 func nextByte(s string, offset int) (byte, int) {
 	if offset >= len(s) {
 	if offset >= len(s) {
 		return 0, 0
 		return 0, 0
@@ -1121,6 +1131,7 @@ type URI struct {
 	Target   string `dns:"octet"`
 	Target   string `dns:"octet"`
 }
 }
 
 
+// rr.Target to be parsed as a sequence of character encoded octets according to RFC 3986
 func (rr *URI) String() string {
 func (rr *URI) String() string {
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Priority)) +
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Priority)) +
 		" " + strconv.Itoa(int(rr.Weight)) + " " + sprintTxtOctet(rr.Target)
 		" " + strconv.Itoa(int(rr.Weight)) + " " + sprintTxtOctet(rr.Target)
@@ -1282,6 +1293,7 @@ type CAA struct {
 	Value string `dns:"octet"`
 	Value string `dns:"octet"`
 }
 }
 
 
+// rr.Value Is the character-string encoding of the value field as specified in RFC 1035, Section 5.1.
 func (rr *CAA) String() string {
 func (rr *CAA) String() string {
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Flag)) + " " + rr.Tag + " " + sprintTxtOctet(rr.Value)
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Flag)) + " " + rr.Tag + " " + sprintTxtOctet(rr.Value)
 }
 }
@@ -1358,6 +1370,23 @@ func (rr *CSYNC) len(off int, compression map[string]struct{}) int {
 	return l
 	return l
 }
 }
 
 
+// ZONEMD RR, from draft-ietf-dnsop-dns-zone-digest
+type ZONEMD struct {
+	Hdr    RR_Header
+	Serial uint32
+	Scheme uint8
+	Hash   uint8
+	Digest string `dns:"hex"`
+}
+
+func (rr *ZONEMD) String() string {
+	return rr.Hdr.String() +
+		strconv.Itoa(int(rr.Serial)) +
+		" " + strconv.Itoa(int(rr.Scheme)) +
+		" " + strconv.Itoa(int(rr.Hash)) +
+		" " + rr.Digest
+}
+
 // APL RR. See RFC 3123.
 // APL RR. See RFC 3123.
 type APL struct {
 type APL struct {
 	Hdr      RR_Header
 	Hdr      RR_Header
@@ -1384,13 +1413,13 @@ func (rr *APL) String() string {
 }
 }
 
 
 // str returns presentation form of the APL prefix.
 // str returns presentation form of the APL prefix.
-func (p *APLPrefix) str() string {
+func (a *APLPrefix) str() string {
 	var sb strings.Builder
 	var sb strings.Builder
-	if p.Negation {
+	if a.Negation {
 		sb.WriteByte('!')
 		sb.WriteByte('!')
 	}
 	}
 
 
-	switch len(p.Network.IP) {
+	switch len(a.Network.IP) {
 	case net.IPv4len:
 	case net.IPv4len:
 		sb.WriteByte('1')
 		sb.WriteByte('1')
 	case net.IPv6len:
 	case net.IPv6len:
@@ -1399,20 +1428,20 @@ func (p *APLPrefix) str() string {
 
 
 	sb.WriteByte(':')
 	sb.WriteByte(':')
 
 
-	switch len(p.Network.IP) {
+	switch len(a.Network.IP) {
 	case net.IPv4len:
 	case net.IPv4len:
-		sb.WriteString(p.Network.IP.String())
+		sb.WriteString(a.Network.IP.String())
 	case net.IPv6len:
 	case net.IPv6len:
 		// add prefix for IPv4-mapped IPv6
 		// add prefix for IPv4-mapped IPv6
-		if v4 := p.Network.IP.To4(); v4 != nil {
+		if v4 := a.Network.IP.To4(); v4 != nil {
 			sb.WriteString("::ffff:")
 			sb.WriteString("::ffff:")
 		}
 		}
-		sb.WriteString(p.Network.IP.String())
+		sb.WriteString(a.Network.IP.String())
 	}
 	}
 
 
 	sb.WriteByte('/')
 	sb.WriteByte('/')
 
 
-	prefix, _ := p.Network.Mask.Size()
+	prefix, _ := a.Network.Mask.Size()
 	sb.WriteString(strconv.Itoa(prefix))
 	sb.WriteString(strconv.Itoa(prefix))
 
 
 	return sb.String()
 	return sb.String()
@@ -1426,17 +1455,17 @@ func (a *APLPrefix) equals(b *APLPrefix) bool {
 }
 }
 
 
 // copy returns a copy of the APL prefix.
 // copy returns a copy of the APL prefix.
-func (p *APLPrefix) copy() APLPrefix {
+func (a *APLPrefix) copy() APLPrefix {
 	return APLPrefix{
 	return APLPrefix{
-		Negation: p.Negation,
-		Network:  copyNet(p.Network),
+		Negation: a.Negation,
+		Network:  copyNet(a.Network),
 	}
 	}
 }
 }
 
 
 // len returns size of the prefix in wire format.
 // len returns size of the prefix in wire format.
-func (p *APLPrefix) len() int {
+func (a *APLPrefix) len() int {
 	// 4-byte header and the network address prefix (see Section 4 of RFC 3123)
 	// 4-byte header and the network address prefix (see Section 4 of RFC 3123)
-	prefix, _ := p.Network.Mask.Size()
+	prefix, _ := a.Network.Mask.Size()
 	return 4 + (prefix+7)/8
 	return 4 + (prefix+7)/8
 }
 }
 
 
@@ -1469,7 +1498,7 @@ func StringToTime(s string) (uint32, error) {
 
 
 // saltToString converts a NSECX salt to uppercase and returns "-" when it is empty.
 // saltToString converts a NSECX salt to uppercase and returns "-" when it is empty.
 func saltToString(s string) string {
 func saltToString(s string) string {
-	if len(s) == 0 {
+	if s == "" {
 		return "-"
 		return "-"
 	}
 	}
 	return strings.ToUpper(s)
 	return strings.ToUpper(s)

+ 21 - 5
vendor/github.com/miekg/dns/types_generate.go

@@ -72,6 +72,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 	if !ok {
 		return nil, false
 		return nil, false
 	}
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 		return st, false
 	}
 	}
@@ -181,6 +184,8 @@ func main() {
 					o("for _, x := range rr.%s { l += len(x) + 1 }\n")
 					o("for _, x := range rr.%s { l += len(x) + 1 }\n")
 				case `dns:"apl"`:
 				case `dns:"apl"`:
 					o("for _, x := range rr.%s { l += x.len() }\n")
 					o("for _, x := range rr.%s { l += x.len() }\n")
+				case `dns:"pairs"`:
+					o("for _, x := range rr.%s { l += 4 + int(x.len()) }\n")
 				default:
 				default:
 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
 				}
 				}
@@ -241,11 +246,15 @@ func main() {
 	for _, name := range namedTypes {
 	for _, name := range namedTypes {
 		o := scope.Lookup(name)
 		o := scope.Lookup(name)
 		st, isEmbedded := getTypeStruct(o.Type(), scope)
 		st, isEmbedded := getTypeStruct(o.Type(), scope)
+		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
+		fields := make([]string, 0, st.NumFields())
 		if isEmbedded {
 		if isEmbedded {
-			continue
+			a, _ := o.Type().Underlying().(*types.Struct)
+			parent := a.Field(0).Name()
+			fields = append(fields, "*rr."+parent+".copy().(*"+parent+")")
+			goto WriteCopy
 		}
 		}
-		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
-		fields := []string{"rr.Hdr"}
+		fields = append(fields, "rr.Hdr")
 		for i := 1; i < st.NumFields(); i++ {
 		for i := 1; i < st.NumFields(); i++ {
 			f := st.Field(i).Name()
 			f := st.Field(i).Name()
 			if sl, ok := st.Field(i).Type().(*types.Slice); ok {
 			if sl, ok := st.Field(i).Type().(*types.Slice); ok {
@@ -263,8 +272,14 @@ func main() {
 					continue
 					continue
 				}
 				}
 				if t == "APLPrefix" {
 				if t == "APLPrefix" {
-					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i := range rr.%s {\n %s[i] = rr.%s[i].copy()\n}\n",
-						f, t, f, f, f, f)
+					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n",
+						f, t, f, f, f)
+					fields = append(fields, f)
+					continue
+				}
+				if t == "SVCBKeyValue" {
+					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n",
+						f, t, f, f, f)
 					fields = append(fields, f)
 					fields = append(fields, f)
 					continue
 					continue
 				}
 				}
@@ -279,6 +294,7 @@ func main() {
 			}
 			}
 			fields = append(fields, "rr."+f)
 			fields = append(fields, "rr."+f)
 		}
 		}
+	WriteCopy:
 		fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ","))
 		fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ","))
 		fmt.Fprintf(b, "}\n")
 		fmt.Fprintf(b, "}\n")
 	}
 	}

+ 3 - 1
vendor/github.com/miekg/dns/update.go

@@ -32,7 +32,9 @@ func (u *Msg) Used(rr []RR) {
 		u.Answer = make([]RR, 0, len(rr))
 		u.Answer = make([]RR, 0, len(rr))
 	}
 	}
 	for _, r := range rr {
 	for _, r := range rr {
-		r.Header().Class = u.Question[0].Qclass
+		hdr := r.Header()
+		hdr.Class = u.Question[0].Qclass
+		hdr.Ttl = 0
 		u.Answer = append(u.Answer, r)
 		u.Answer = append(u.Answer, r)
 	}
 	}
 }
 }

+ 1 - 1
vendor/github.com/miekg/dns/version.go

@@ -3,7 +3,7 @@ package dns
 import "fmt"
 import "fmt"
 
 
 // Version is current version of this library.
 // Version is current version of this library.
-var Version = v{1, 1, 29}
+var Version = v{1, 1, 43}
 
 
 // v holds the version of this library.
 // v holds the version of this library.
 type v struct {
 type v struct {

+ 183 - 0
vendor/github.com/miekg/dns/zduplicate.go

@@ -104,6 +104,48 @@ func (r1 *CAA) isDuplicate(_r2 RR) bool {
 	return true
 	return true
 }
 }
 
 
+func (r1 *CDNSKEY) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*CDNSKEY)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Flags != r2.Flags {
+		return false
+	}
+	if r1.Protocol != r2.Protocol {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.PublicKey != r2.PublicKey {
+		return false
+	}
+	return true
+}
+
+func (r1 *CDS) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*CDS)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.DigestType != r2.DigestType {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}
+
 func (r1 *CERT) isDuplicate(_r2 RR) bool {
 func (r1 *CERT) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*CERT)
 	r2, ok := _r2.(*CERT)
 	if !ok {
 	if !ok {
@@ -172,6 +214,27 @@ func (r1 *DHCID) isDuplicate(_r2 RR) bool {
 	return true
 	return true
 }
 }
 
 
+func (r1 *DLV) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*DLV)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.DigestType != r2.DigestType {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}
+
 func (r1 *DNAME) isDuplicate(_r2 RR) bool {
 func (r1 *DNAME) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*DNAME)
 	r2, ok := _r2.(*DNAME)
 	if !ok {
 	if !ok {
@@ -339,6 +402,48 @@ func (r1 *HIP) isDuplicate(_r2 RR) bool {
 	return true
 	return true
 }
 }
 
 
+func (r1 *HTTPS) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*HTTPS)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Priority != r2.Priority {
+		return false
+	}
+	if !isDuplicateName(r1.Target, r2.Target) {
+		return false
+	}
+	if len(r1.Value) != len(r2.Value) {
+		return false
+	}
+	if !areSVCBPairArraysEqual(r1.Value, r2.Value) {
+		return false
+	}
+	return true
+}
+
+func (r1 *KEY) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*KEY)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Flags != r2.Flags {
+		return false
+	}
+	if r1.Protocol != r2.Protocol {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.PublicKey != r2.PublicKey {
+		return false
+	}
+	return true
+}
+
 func (r1 *KX) isDuplicate(_r2 RR) bool {
 func (r1 *KX) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*KX)
 	r2, ok := _r2.(*KX)
 	if !ok {
 	if !ok {
@@ -849,6 +954,42 @@ func (r1 *RT) isDuplicate(_r2 RR) bool {
 	return true
 	return true
 }
 }
 
 
+func (r1 *SIG) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*SIG)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.TypeCovered != r2.TypeCovered {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.Labels != r2.Labels {
+		return false
+	}
+	if r1.OrigTtl != r2.OrigTtl {
+		return false
+	}
+	if r1.Expiration != r2.Expiration {
+		return false
+	}
+	if r1.Inception != r2.Inception {
+		return false
+	}
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if !isDuplicateName(r1.SignerName, r2.SignerName) {
+		return false
+	}
+	if r1.Signature != r2.Signature {
+		return false
+	}
+	return true
+}
+
 func (r1 *SMIMEA) isDuplicate(_r2 RR) bool {
 func (r1 *SMIMEA) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*SMIMEA)
 	r2, ok := _r2.(*SMIMEA)
 	if !ok {
 	if !ok {
@@ -956,6 +1097,27 @@ func (r1 *SSHFP) isDuplicate(_r2 RR) bool {
 	return true
 	return true
 }
 }
 
 
+func (r1 *SVCB) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*SVCB)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Priority != r2.Priority {
+		return false
+	}
+	if !isDuplicateName(r1.Target, r2.Target) {
+		return false
+	}
+	if len(r1.Value) != len(r2.Value) {
+		return false
+	}
+	if !areSVCBPairArraysEqual(r1.Value, r2.Value) {
+		return false
+	}
+	return true
+}
+
 func (r1 *TA) isDuplicate(_r2 RR) bool {
 func (r1 *TA) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*TA)
 	r2, ok := _r2.(*TA)
 	if !ok {
 	if !ok {
@@ -1155,3 +1317,24 @@ func (r1 *X25) isDuplicate(_r2 RR) bool {
 	}
 	}
 	return true
 	return true
 }
 }
+
+func (r1 *ZONEMD) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*ZONEMD)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Serial != r2.Serial {
+		return false
+	}
+	if r1.Scheme != r2.Scheme {
+		return false
+	}
+	if r1.Hash != r2.Hash {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}

+ 134 - 0
vendor/github.com/miekg/dns/zmsg.go

@@ -316,6 +316,22 @@ func (rr *HIP) pack(msg []byte, off int, compression compressionMap, compress bo
 	return off, nil
 	return off, nil
 }
 }
 
 
+func (rr *HTTPS) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint16(rr.Priority, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDomainName(rr.Target, msg, off, compression, false)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDataSVCB(rr.Value, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *KEY) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 func (rr *KEY) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 	off, err = packUint16(rr.Flags, msg, off)
 	off, err = packUint16(rr.Flags, msg, off)
 	if err != nil {
 	if err != nil {
@@ -906,6 +922,22 @@ func (rr *SSHFP) pack(msg []byte, off int, compression compressionMap, compress
 	return off, nil
 	return off, nil
 }
 }
 
 
+func (rr *SVCB) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint16(rr.Priority, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDomainName(rr.Target, msg, off, compression, false)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDataSVCB(rr.Value, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *TA) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 func (rr *TA) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 	off, err = packUint16(rr.KeyTag, msg, off)
 	off, err = packUint16(rr.KeyTag, msg, off)
 	if err != nil {
 	if err != nil {
@@ -1086,6 +1118,26 @@ func (rr *X25) pack(msg []byte, off int, compression compressionMap, compress bo
 	return off, nil
 	return off, nil
 }
 }
 
 
+func (rr *ZONEMD) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint32(rr.Serial, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packUint8(rr.Scheme, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packUint8(rr.Hash, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packStringHex(rr.Digest, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 // unpack*() functions
 // unpack*() functions
 
 
 func (rr *A) unpack(msg []byte, off int) (off1 int, err error) {
 func (rr *A) unpack(msg []byte, off int) (off1 int, err error) {
@@ -1559,6 +1611,31 @@ func (rr *HIP) unpack(msg []byte, off int) (off1 int, err error) {
 	return off, nil
 	return off, nil
 }
 }
 
 
+func (rr *HTTPS) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Priority, off, err = unpackUint16(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Target, off, err = UnpackDomainName(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Value, off, err = unpackDataSVCB(msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *KEY) unpack(msg []byte, off int) (off1 int, err error) {
 func (rr *KEY) unpack(msg []byte, off int) (off1 int, err error) {
 	rdStart := off
 	rdStart := off
 	_ = rdStart
 	_ = rdStart
@@ -2461,6 +2538,31 @@ func (rr *SSHFP) unpack(msg []byte, off int) (off1 int, err error) {
 	return off, nil
 	return off, nil
 }
 }
 
 
+func (rr *SVCB) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Priority, off, err = unpackUint16(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Target, off, err = UnpackDomainName(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Value, off, err = unpackDataSVCB(msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *TA) unpack(msg []byte, off int) (off1 int, err error) {
 func (rr *TA) unpack(msg []byte, off int) (off1 int, err error) {
 	rdStart := off
 	rdStart := off
 	_ = rdStart
 	_ = rdStart
@@ -2739,3 +2841,35 @@ func (rr *X25) unpack(msg []byte, off int) (off1 int, err error) {
 	}
 	}
 	return off, nil
 	return off, nil
 }
 }
+
+func (rr *ZONEMD) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Serial, off, err = unpackUint32(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Scheme, off, err = unpackUint8(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Hash, off, err = unpackUint8(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Digest, off, err = unpackStringHex(msg, off, rdStart+int(rr.Hdr.Rdlength))
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}

+ 56 - 2
vendor/github.com/miekg/dns/ztypes.go

@@ -33,6 +33,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeGPOS:       func() RR { return new(GPOS) },
 	TypeGPOS:       func() RR { return new(GPOS) },
 	TypeHINFO:      func() RR { return new(HINFO) },
 	TypeHINFO:      func() RR { return new(HINFO) },
 	TypeHIP:        func() RR { return new(HIP) },
 	TypeHIP:        func() RR { return new(HIP) },
+	TypeHTTPS:      func() RR { return new(HTTPS) },
 	TypeKEY:        func() RR { return new(KEY) },
 	TypeKEY:        func() RR { return new(KEY) },
 	TypeKX:         func() RR { return new(KX) },
 	TypeKX:         func() RR { return new(KX) },
 	TypeL32:        func() RR { return new(L32) },
 	TypeL32:        func() RR { return new(L32) },
@@ -70,6 +71,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeSPF:        func() RR { return new(SPF) },
 	TypeSPF:        func() RR { return new(SPF) },
 	TypeSRV:        func() RR { return new(SRV) },
 	TypeSRV:        func() RR { return new(SRV) },
 	TypeSSHFP:      func() RR { return new(SSHFP) },
 	TypeSSHFP:      func() RR { return new(SSHFP) },
+	TypeSVCB:       func() RR { return new(SVCB) },
 	TypeTA:         func() RR { return new(TA) },
 	TypeTA:         func() RR { return new(TA) },
 	TypeTALINK:     func() RR { return new(TALINK) },
 	TypeTALINK:     func() RR { return new(TALINK) },
 	TypeTKEY:       func() RR { return new(TKEY) },
 	TypeTKEY:       func() RR { return new(TKEY) },
@@ -80,6 +82,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeUINFO:      func() RR { return new(UINFO) },
 	TypeUINFO:      func() RR { return new(UINFO) },
 	TypeURI:        func() RR { return new(URI) },
 	TypeURI:        func() RR { return new(URI) },
 	TypeX25:        func() RR { return new(X25) },
 	TypeX25:        func() RR { return new(X25) },
+	TypeZONEMD:     func() RR { return new(ZONEMD) },
 }
 }
 
 
 // TypeToString is a map of strings for each RR type.
 // TypeToString is a map of strings for each RR type.
@@ -110,6 +113,7 @@ var TypeToString = map[uint16]string{
 	TypeGPOS:       "GPOS",
 	TypeGPOS:       "GPOS",
 	TypeHINFO:      "HINFO",
 	TypeHINFO:      "HINFO",
 	TypeHIP:        "HIP",
 	TypeHIP:        "HIP",
+	TypeHTTPS:      "HTTPS",
 	TypeISDN:       "ISDN",
 	TypeISDN:       "ISDN",
 	TypeIXFR:       "IXFR",
 	TypeIXFR:       "IXFR",
 	TypeKEY:        "KEY",
 	TypeKEY:        "KEY",
@@ -153,6 +157,7 @@ var TypeToString = map[uint16]string{
 	TypeSPF:        "SPF",
 	TypeSPF:        "SPF",
 	TypeSRV:        "SRV",
 	TypeSRV:        "SRV",
 	TypeSSHFP:      "SSHFP",
 	TypeSSHFP:      "SSHFP",
+	TypeSVCB:       "SVCB",
 	TypeTA:         "TA",
 	TypeTA:         "TA",
 	TypeTALINK:     "TALINK",
 	TypeTALINK:     "TALINK",
 	TypeTKEY:       "TKEY",
 	TypeTKEY:       "TKEY",
@@ -164,6 +169,7 @@ var TypeToString = map[uint16]string{
 	TypeUNSPEC:     "UNSPEC",
 	TypeUNSPEC:     "UNSPEC",
 	TypeURI:        "URI",
 	TypeURI:        "URI",
 	TypeX25:        "X25",
 	TypeX25:        "X25",
+	TypeZONEMD:     "ZONEMD",
 	TypeNSAPPTR:    "NSAP-PTR",
 	TypeNSAPPTR:    "NSAP-PTR",
 }
 }
 
 
@@ -191,6 +197,7 @@ func (rr *GID) Header() *RR_Header        { return &rr.Hdr }
 func (rr *GPOS) Header() *RR_Header       { return &rr.Hdr }
 func (rr *GPOS) Header() *RR_Header       { return &rr.Hdr }
 func (rr *HINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *HINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *HIP) Header() *RR_Header        { return &rr.Hdr }
 func (rr *HIP) Header() *RR_Header        { return &rr.Hdr }
+func (rr *HTTPS) Header() *RR_Header      { return &rr.Hdr }
 func (rr *KEY) Header() *RR_Header        { return &rr.Hdr }
 func (rr *KEY) Header() *RR_Header        { return &rr.Hdr }
 func (rr *KX) Header() *RR_Header         { return &rr.Hdr }
 func (rr *KX) Header() *RR_Header         { return &rr.Hdr }
 func (rr *L32) Header() *RR_Header        { return &rr.Hdr }
 func (rr *L32) Header() *RR_Header        { return &rr.Hdr }
@@ -229,6 +236,7 @@ func (rr *SOA) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SPF) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SPF) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SRV) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SRV) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SSHFP) Header() *RR_Header      { return &rr.Hdr }
 func (rr *SSHFP) Header() *RR_Header      { return &rr.Hdr }
+func (rr *SVCB) Header() *RR_Header       { return &rr.Hdr }
 func (rr *TA) Header() *RR_Header         { return &rr.Hdr }
 func (rr *TA) Header() *RR_Header         { return &rr.Hdr }
 func (rr *TALINK) Header() *RR_Header     { return &rr.Hdr }
 func (rr *TALINK) Header() *RR_Header     { return &rr.Hdr }
 func (rr *TKEY) Header() *RR_Header       { return &rr.Hdr }
 func (rr *TKEY) Header() *RR_Header       { return &rr.Hdr }
@@ -239,6 +247,7 @@ func (rr *UID) Header() *RR_Header        { return &rr.Hdr }
 func (rr *UINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *UINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *URI) Header() *RR_Header        { return &rr.Hdr }
 func (rr *URI) Header() *RR_Header        { return &rr.Hdr }
 func (rr *X25) Header() *RR_Header        { return &rr.Hdr }
 func (rr *X25) Header() *RR_Header        { return &rr.Hdr }
+func (rr *ZONEMD) Header() *RR_Header     { return &rr.Hdr }
 
 
 // len() functions
 // len() functions
 func (rr *A) len(off int, compression map[string]struct{}) int {
 func (rr *A) len(off int, compression map[string]struct{}) int {
@@ -592,6 +601,15 @@ func (rr *SSHFP) len(off int, compression map[string]struct{}) int {
 	l += len(rr.FingerPrint) / 2
 	l += len(rr.FingerPrint) / 2
 	return l
 	return l
 }
 }
+func (rr *SVCB) len(off int, compression map[string]struct{}) int {
+	l := rr.Hdr.len(off, compression)
+	l += 2 // Priority
+	l += domainNameLen(rr.Target, off+l, compression, false)
+	for _, x := range rr.Value {
+		l += 4 + int(x.len())
+	}
+	return l
+}
 func (rr *TA) len(off int, compression map[string]struct{}) int {
 func (rr *TA) len(off int, compression map[string]struct{}) int {
 	l := rr.Hdr.len(off, compression)
 	l := rr.Hdr.len(off, compression)
 	l += 2 // KeyTag
 	l += 2 // KeyTag
@@ -669,6 +687,14 @@ func (rr *X25) len(off int, compression map[string]struct{}) int {
 	l += len(rr.PSDNAddress) + 1
 	l += len(rr.PSDNAddress) + 1
 	return l
 	return l
 }
 }
+func (rr *ZONEMD) len(off int, compression map[string]struct{}) int {
+	l := rr.Hdr.len(off, compression)
+	l += 4 // Serial
+	l++    // Scheme
+	l++    // Hash
+	l += len(rr.Digest) / 2
+	return l
+}
 
 
 // copy() functions
 // copy() functions
 func (rr *A) copy() RR {
 func (rr *A) copy() RR {
@@ -685,8 +711,8 @@ func (rr *ANY) copy() RR {
 }
 }
 func (rr *APL) copy() RR {
 func (rr *APL) copy() RR {
 	Prefixes := make([]APLPrefix, len(rr.Prefixes))
 	Prefixes := make([]APLPrefix, len(rr.Prefixes))
-	for i := range rr.Prefixes {
-		Prefixes[i] = rr.Prefixes[i].copy()
+	for i, e := range rr.Prefixes {
+		Prefixes[i] = e.copy()
 	}
 	}
 	return &APL{rr.Hdr, Prefixes}
 	return &APL{rr.Hdr, Prefixes}
 }
 }
@@ -698,6 +724,12 @@ func (rr *AVC) copy() RR {
 func (rr *CAA) copy() RR {
 func (rr *CAA) copy() RR {
 	return &CAA{rr.Hdr, rr.Flag, rr.Tag, rr.Value}
 	return &CAA{rr.Hdr, rr.Flag, rr.Tag, rr.Value}
 }
 }
+func (rr *CDNSKEY) copy() RR {
+	return &CDNSKEY{*rr.DNSKEY.copy().(*DNSKEY)}
+}
+func (rr *CDS) copy() RR {
+	return &CDS{*rr.DS.copy().(*DS)}
+}
 func (rr *CERT) copy() RR {
 func (rr *CERT) copy() RR {
 	return &CERT{rr.Hdr, rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
 	return &CERT{rr.Hdr, rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
 }
 }
@@ -712,6 +744,9 @@ func (rr *CSYNC) copy() RR {
 func (rr *DHCID) copy() RR {
 func (rr *DHCID) copy() RR {
 	return &DHCID{rr.Hdr, rr.Digest}
 	return &DHCID{rr.Hdr, rr.Digest}
 }
 }
+func (rr *DLV) copy() RR {
+	return &DLV{*rr.DS.copy().(*DS)}
+}
 func (rr *DNAME) copy() RR {
 func (rr *DNAME) copy() RR {
 	return &DNAME{rr.Hdr, rr.Target}
 	return &DNAME{rr.Hdr, rr.Target}
 }
 }
@@ -744,6 +779,12 @@ func (rr *HIP) copy() RR {
 	copy(RendezvousServers, rr.RendezvousServers)
 	copy(RendezvousServers, rr.RendezvousServers)
 	return &HIP{rr.Hdr, rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
 	return &HIP{rr.Hdr, rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
 }
 }
+func (rr *HTTPS) copy() RR {
+	return &HTTPS{*rr.SVCB.copy().(*SVCB)}
+}
+func (rr *KEY) copy() RR {
+	return &KEY{*rr.DNSKEY.copy().(*DNSKEY)}
+}
 func (rr *KX) copy() RR {
 func (rr *KX) copy() RR {
 	return &KX{rr.Hdr, rr.Preference, rr.Exchanger}
 	return &KX{rr.Hdr, rr.Preference, rr.Exchanger}
 }
 }
@@ -847,6 +888,9 @@ func (rr *RRSIG) copy() RR {
 func (rr *RT) copy() RR {
 func (rr *RT) copy() RR {
 	return &RT{rr.Hdr, rr.Preference, rr.Host}
 	return &RT{rr.Hdr, rr.Preference, rr.Host}
 }
 }
+func (rr *SIG) copy() RR {
+	return &SIG{*rr.RRSIG.copy().(*RRSIG)}
+}
 func (rr *SMIMEA) copy() RR {
 func (rr *SMIMEA) copy() RR {
 	return &SMIMEA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
 	return &SMIMEA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
 }
 }
@@ -864,6 +908,13 @@ func (rr *SRV) copy() RR {
 func (rr *SSHFP) copy() RR {
 func (rr *SSHFP) copy() RR {
 	return &SSHFP{rr.Hdr, rr.Algorithm, rr.Type, rr.FingerPrint}
 	return &SSHFP{rr.Hdr, rr.Algorithm, rr.Type, rr.FingerPrint}
 }
 }
+func (rr *SVCB) copy() RR {
+	Value := make([]SVCBKeyValue, len(rr.Value))
+	for i, e := range rr.Value {
+		Value[i] = e.copy()
+	}
+	return &SVCB{rr.Hdr, rr.Priority, rr.Target, Value}
+}
 func (rr *TA) copy() RR {
 func (rr *TA) copy() RR {
 	return &TA{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
 	return &TA{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
 }
 }
@@ -896,3 +947,6 @@ func (rr *URI) copy() RR {
 func (rr *X25) copy() RR {
 func (rr *X25) copy() RR {
 	return &X25{rr.Hdr, rr.PSDNAddress}
 	return &X25{rr.Hdr, rr.PSDNAddress}
 }
 }
+func (rr *ZONEMD) copy() RR {
+	return &ZONEMD{rr.Hdr, rr.Serial, rr.Scheme, rr.Hash, rr.Digest}
+}

+ 20 - 0
vendor/github.com/refraction-networking/utls/README.md

@@ -102,6 +102,26 @@ you can set UConn.HandshakeStateBuilt = true, and marshal clientHello into UConn
 In this case you will be responsible for modifying other parts of Config and ClientHelloMsg to reflect your setup
 In this case you will be responsible for modifying other parts of Config and ClientHelloMsg to reflect your setup
 and not confuse "crypto/tls", which will be processing response from server.
 and not confuse "crypto/tls", which will be processing response from server.
 
 
+### Fingerprinting Captured Client Hello
+You can use a captured client hello to generate new ones that mimic/have the same properties as the original.
+The generated client hellos _should_ look like they were generated from the same client software as the original fingerprinted bytes.
+In order to do this:
+1) Create a `ClientHelloSpec` from the raw bytes of the original client hello
+2) Use `HelloCustom` as an argument for `UClient()` to get empty config
+3) Use `ApplyPreset` with the generated `ClientHelloSpec` to set the appropriate connection properties
+```
+uConn := UClient(&net.TCPConn{}, nil, HelloCustom)
+fingerprinter := &Fingerprinter{}
+generatedSpec, err := fingerprinter.FingerprintClientHello(rawCapturedClientHelloBytes)
+if err != nil {
+  panic("fingerprinting failed: %v", err)
+}
+if err := uConn.ApplyPreset(generatedSpec); err != nil {
+  panic("applying generated spec failed: %v", err)
+}
+```
+The `rawCapturedClientHelloBytes` should be the full tls record, including the record type/version/length header.
+
 ## Roller
 ## Roller
 
 
 A simple wrapper, that allows to easily use multiple latest(auto-updated) fingerprints.
 A simple wrapper, that allows to easily use multiple latest(auto-updated) fingerprints.

+ 5 - 1
vendor/github.com/refraction-networking/utls/key_agreement.go

@@ -69,7 +69,11 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}
 
 
-	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
+	rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
+	if !ok {
+		return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
+	}
+	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
 	if err != nil {
 	if err != nil {
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}

+ 16 - 1
vendor/github.com/refraction-networking/utls/ticket.go

@@ -171,7 +171,20 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
 	return encrypted, nil
 	return encrypted, nil
 }
 }
 
 
+// [uTLS] changed to use exported DecryptTicketWith func below
 func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
 func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
+	tks := ticketKeys(c.config.ticketKeys()).ToPublic()
+	return DecryptTicketWith(encrypted, tks)
+}
+
+// DecryptTicketWith decrypts an encrypted session ticket
+// using a TicketKeys (ie []TicketKey) struct
+//
+// usedOldKey will be true if the key used for decryption is
+// not the first in the []TicketKey slice
+//
+// [uTLS] changed to be made public and take a TicketKeys instead of use a Conn receiver
+func DecryptTicketWith(encrypted []byte, tks TicketKeys) (plaintext []byte, usedOldKey bool) {
 	if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
 	if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
 		return nil, false
 		return nil, false
 	}
 	}
@@ -181,7 +194,9 @@ func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey boo
 	macBytes := encrypted[len(encrypted)-sha256.Size:]
 	macBytes := encrypted[len(encrypted)-sha256.Size:]
 	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
 	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
 
 
-	keys := c.config.ticketKeys()
+	// keys := c.config.ticketKeys() // [uTLS] keys are received as a function argument
+
+	keys := tks.ToPrivate()
 	keyIndex := -1
 	keyIndex := -1
 	for i, candidateKey := range keys {
 	for i, candidateKey := range keys {
 		if bytes.Equal(keyName, candidateKey.keyName[:]) {
 		if bytes.Equal(keyName, candidateKey.keyName[:]) {

+ 15 - 0
vendor/github.com/refraction-networking/utls/u_common.go

@@ -39,6 +39,7 @@ const (
 
 
 	FAKE_TLS_DHE_RSA_WITH_AES_128_CBC_SHA  = uint16(0x0033)
 	FAKE_TLS_DHE_RSA_WITH_AES_128_CBC_SHA  = uint16(0x0033)
 	FAKE_TLS_DHE_RSA_WITH_AES_256_CBC_SHA  = uint16(0x0039)
 	FAKE_TLS_DHE_RSA_WITH_AES_256_CBC_SHA  = uint16(0x0039)
+	FAKE_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384  = uint16(0x009f)
 	FAKE_TLS_RSA_WITH_RC4_128_MD5          = uint16(0x0004)
 	FAKE_TLS_RSA_WITH_RC4_128_MD5          = uint16(0x0004)
 	FAKE_TLS_EMPTY_RENEGOTIATION_INFO_SCSV = uint16(0x00ff)
 	FAKE_TLS_EMPTY_RENEGOTIATION_INFO_SCSV = uint16(0x00ff)
 )
 )
@@ -161,6 +162,20 @@ var (
 // https://tools.ietf.org/html/draft-ietf-tls-grease-01
 // https://tools.ietf.org/html/draft-ietf-tls-grease-01
 const GREASE_PLACEHOLDER = 0x0a0a
 const GREASE_PLACEHOLDER = 0x0a0a
 
 
+func isGREASEUint16(v uint16) bool {
+	// First byte is same as second byte
+	// and lowest nibble is 0xa
+	return ((v >> 8) == v&0xff) && v&0xf == 0xa
+}
+
+func unGREASEUint16(v uint16) uint16 {
+	if isGREASEUint16(v) {
+		return GREASE_PLACEHOLDER
+	} else {
+		return v
+	}
+}
+
 // utlsMacSHA384 returns a SHA-384 based MAC. These are only supported in TLS 1.2
 // utlsMacSHA384 returns a SHA-384 based MAC. These are only supported in TLS 1.2
 // so the given version is ignored.
 // so the given version is ignored.
 func utlsMacSHA384(version uint16, key []byte) macFunction {
 func utlsMacSHA384(version uint16, key []byte) macFunction {

+ 360 - 0
vendor/github.com/refraction-networking/utls/u_fingerprinter.go

@@ -0,0 +1,360 @@
+// Copyright 2017 Google Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"golang.org/x/crypto/cryptobyte"
+)
+
+// Fingerprinter is a struct largely for holding options for the FingerprintClientHello func
+type Fingerprinter struct {
+	// KeepPSK will ensure that the PreSharedKey extension is passed along into the resulting ClientHelloSpec as-is
+	KeepPSK bool
+	// AllowBluntMimicry will ensure that unknown extensions are
+	// passed along into the resulting ClientHelloSpec as-is
+	// It will not ensure that the PSK is passed along, if you require that, use KeepPSK
+	// WARNING: there could be numerous subtle issues with ClientHelloSpecs
+	// that are generated with this flag which could compromise security and/or mimicry
+	AllowBluntMimicry bool
+	// AlwaysAddPadding will always add a UtlsPaddingExtension with BoringPaddingStyle
+	// at the end of the extensions list if it isn't found in the fingerprinted hello.
+	// This could be useful in scenarios where the hello you are fingerprinting does not
+	// have any padding, but you suspect that other changes you make to the final hello
+	// (including things like different SNI lengths) would cause padding to be necessary
+	AlwaysAddPadding bool
+}
+
+// FingerprintClientHello returns a ClientHelloSpec which is based on the
+// ClientHello that is passed in as the data argument
+//
+// If the ClientHello passed in has extensions that are not recognized or cannot be handled
+// it will return a non-nil error and a nil *ClientHelloSpec value
+//
+// The data should be the full tls record, including the record type/version/length header
+// as well as the handshake type/length/version header
+// https://tools.ietf.org/html/rfc5246#section-6.2
+// https://tools.ietf.org/html/rfc5246#section-7.4
+func (f *Fingerprinter) FingerprintClientHello(data []byte) (*ClientHelloSpec, error) {
+	clientHelloSpec := &ClientHelloSpec{}
+	s := cryptobyte.String(data)
+
+	var contentType uint8
+	var recordVersion uint16
+	if !s.ReadUint8(&contentType) || // record type
+		!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
+		return nil, errors.New("unable to read record type, version, and length")
+	}
+
+	if recordType(contentType) != recordTypeHandshake {
+		return nil, errors.New("record is not a handshake")
+	}
+
+	var handshakeVersion uint16
+	var handshakeType uint8
+
+	if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
+		!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
+		return nil, errors.New("unable to read handshake message type, length, and random")
+	}
+
+	if handshakeType != typeClientHello {
+		return nil, errors.New("handshake message is not a ClientHello")
+	}
+
+	clientHelloSpec.TLSVersMin = recordVersion
+	clientHelloSpec.TLSVersMax = handshakeVersion
+
+	var ignoredSessionID cryptobyte.String
+	if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
+		return nil, errors.New("unable to read session id")
+	}
+
+	var cipherSuitesBytes cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
+		return nil, errors.New("unable to read ciphersuites")
+	}
+	cipherSuites := []uint16{}
+	for !cipherSuitesBytes.Empty() {
+		var suite uint16
+		if !cipherSuitesBytes.ReadUint16(&suite) {
+			return nil, errors.New("unable to read ciphersuite")
+		}
+		cipherSuites = append(cipherSuites, unGREASEUint16(suite))
+	}
+	clientHelloSpec.CipherSuites = cipherSuites
+
+	if !readUint8LengthPrefixed(&s, &clientHelloSpec.CompressionMethods) {
+		return nil, errors.New("unable to read compression methods")
+	}
+
+	if s.Empty() {
+		// ClientHello is optionally followed by extension data
+		return clientHelloSpec, nil
+	}
+
+	var extensions cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
+		return nil, errors.New("unable to read extensions data")
+	}
+
+	for !extensions.Empty() {
+		var extension uint16
+		var extData cryptobyte.String
+		if !extensions.ReadUint16(&extension) ||
+			!extensions.ReadUint16LengthPrefixed(&extData) {
+			return nil, errors.New("unable to read extension data")
+		}
+
+		switch extension {
+		case extensionServerName:
+			// RFC 6066, Section 3
+			var nameList cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
+				return nil, errors.New("unable to read server name extension data")
+			}
+			var serverName string
+			for !nameList.Empty() {
+				var nameType uint8
+				var serverNameBytes cryptobyte.String
+				if !nameList.ReadUint8(&nameType) ||
+					!nameList.ReadUint16LengthPrefixed(&serverNameBytes) ||
+					serverNameBytes.Empty() {
+					return nil, errors.New("unable to read server name extension data")
+				}
+				if nameType != 0 {
+					continue
+				}
+				if len(serverName) != 0 {
+					return nil, errors.New("multiple names of the same name_type in server name extension are prohibited")
+				}
+				serverName = string(serverNameBytes)
+				if strings.HasSuffix(serverName, ".") {
+					return nil, errors.New("SNI value may not include a trailing dot")
+				}
+
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SNIExtension{})
+
+			}
+		case extensionNextProtoNeg:
+			// draft-agl-tls-nextprotoneg-04
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &NPNExtension{})
+
+		case extensionStatusRequest:
+			// RFC 4366, Section 3.6
+			var statusType uint8
+			var ignored cryptobyte.String
+			if !extData.ReadUint8(&statusType) ||
+				!extData.ReadUint16LengthPrefixed(&ignored) ||
+				!extData.ReadUint16LengthPrefixed(&ignored) {
+				return nil, errors.New("unable to read status request extension data")
+			}
+
+			if statusType == statusTypeOCSP {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &StatusRequestExtension{})
+			} else {
+				return nil, errors.New("status request extension statusType is not statusTypeOCSP")
+			}
+
+		case extensionSupportedCurves:
+			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
+			var curvesBytes cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&curvesBytes) || curvesBytes.Empty() {
+				return nil, errors.New("unable to read supported curves extension data")
+			}
+			curves := []CurveID{}
+			for !curvesBytes.Empty() {
+				var curve uint16
+				if !curvesBytes.ReadUint16(&curve) {
+					return nil, errors.New("unable to read supported curves extension data")
+				}
+				curves = append(curves, CurveID(unGREASEUint16(curve)))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedCurvesExtension{curves})
+
+		case extensionSupportedPoints:
+			// RFC 4492, Section 5.1.2
+			supportedPoints := []uint8{}
+			if !readUint8LengthPrefixed(&extData, &supportedPoints) ||
+				len(supportedPoints) == 0 {
+				return nil, errors.New("unable to read supported points extension data")
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedPointsExtension{supportedPoints})
+
+		case extensionSessionTicket:
+			// RFC 5077, Section 3.2
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SessionTicketExtension{})
+
+		case extensionSignatureAlgorithms:
+			// RFC 5246, Section 7.4.1.4.1
+			var sigAndAlgs cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+				return nil, errors.New("unable to read signature algorithms extension data")
+			}
+			supportedSignatureAlgorithms := []SignatureScheme{}
+			for !sigAndAlgs.Empty() {
+				var sigAndAlg uint16
+				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+					return nil, errors.New("unable to read signature algorithms extension data")
+				}
+				supportedSignatureAlgorithms = append(
+					supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SignatureAlgorithmsExtension{supportedSignatureAlgorithms})
+
+		case extensionSignatureAlgorithmsCert:
+			// RFC 8446, Section 4.2.3
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension SignatureAlgorithmsCert")
+			}
+
+		case extensionRenegotiationInfo:
+			// RFC 5746, Section 3.2
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &RenegotiationInfoExtension{RenegotiateOnceAsClient})
+
+		case extensionALPN:
+			// RFC 7301, Section 3.1
+			var protoList cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
+				return nil, errors.New("unable to read ALPN extension data")
+			}
+			alpnProtocols := []string{}
+			for !protoList.Empty() {
+				var proto cryptobyte.String
+				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
+					return nil, errors.New("unable to read ALPN extension data")
+				}
+				alpnProtocols = append(alpnProtocols, string(proto))
+
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &ALPNExtension{alpnProtocols})
+
+		case extensionSCT:
+			// RFC 6962, Section 3.3.1
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SCTExtension{})
+
+		case extensionSupportedVersions:
+			// RFC 8446, Section 4.2.1
+			var versList cryptobyte.String
+			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
+				return nil, errors.New("unable to read supported versions extension data")
+			}
+			supportedVersions := []uint16{}
+			for !versList.Empty() {
+				var vers uint16
+				if !versList.ReadUint16(&vers) {
+					return nil, errors.New("unable to read supported versions extension data")
+				}
+				supportedVersions = append(supportedVersions, unGREASEUint16(vers))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedVersionsExtension{supportedVersions})
+			// If SupportedVersionsExtension is present, use that instead of record+handshake versions
+			clientHelloSpec.TLSVersMin = 0
+			clientHelloSpec.TLSVersMax = 0
+
+		case extensionKeyShare:
+			// RFC 8446, Section 4.2.8
+			var clientShares cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&clientShares) {
+				return nil, errors.New("unable to read key share extension data")
+			}
+			keyShares := []KeyShare{}
+			for !clientShares.Empty() {
+				var ks KeyShare
+				var group uint16
+				if !clientShares.ReadUint16(&group) ||
+					!readUint16LengthPrefixed(&clientShares, &ks.Data) ||
+					len(ks.Data) == 0 {
+					return nil, errors.New("unable to read key share extension data")
+				}
+				ks.Group = CurveID(unGREASEUint16(group))
+				// if not GREASE, key share data will be discarded as it should
+				// be generated per connection
+				if ks.Group != GREASE_PLACEHOLDER {
+					ks.Data = nil
+				}
+				keyShares = append(keyShares, ks)
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &KeyShareExtension{keyShares})
+
+		case extensionPSKModes:
+			// RFC 8446, Section 4.2.9
+			// TODO: PSK Modes have their own form of GREASE-ing which is not currently implemented
+			// the current functionality will NOT re-GREASE/re-randomize these values when using a fingerprinted spec
+			// https://github.com/refraction-networking/utls/pull/58#discussion_r522354105
+			// https://tools.ietf.org/html/draft-ietf-tls-grease-01#section-2
+			pskModes := []uint8{}
+			if !readUint8LengthPrefixed(&extData, &pskModes) {
+				return nil, errors.New("unable to read PSK extension data")
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &PSKKeyExchangeModesExtension{pskModes})
+
+		case utlsExtensionExtendedMasterSecret:
+			// https://tools.ietf.org/html/rfc7627
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsExtendedMasterSecretExtension{})
+
+		case utlsExtensionPadding:
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
+
+		case fakeExtensionChannelID, fakeCertCompressionAlgs, fakeRecordSizeLimit:
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+
+		case extensionPreSharedKey:
+			// RFC 8446, Section 4.2.11
+			if f.KeepPSK {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension PreSharedKey")
+			}
+
+		case extensionCookie:
+			// RFC 8446, Section 4.2.2
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension Cookie")
+			}
+
+		case extensionEarlyData:
+			// RFC 8446, Section 4.2.10
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension EarlyData")
+			}
+
+		default:
+			if isGREASEUint16(extension) {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsGREASEExtension{unGREASEUint16(extension), extData})
+			} else if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, fmt.Errorf("unsupported extension %#x", extension)
+			}
+
+			continue
+		}
+	}
+
+	if f.AlwaysAddPadding {
+		alreadyHasPadding := false
+		for _, ext := range clientHelloSpec.Extensions {
+			if _, ok := ext.(*UtlsPaddingExtension); ok {
+				alreadyHasPadding = true
+				break
+			}
+		}
+		if !alreadyHasPadding {
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
+		}
+	}
+
+	return clientHelloSpec, nil
+}

+ 10 - 0
vendor/github.com/refraction-networking/utls/u_parrots.go

@@ -617,6 +617,9 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
 	uconn.Extensions = make([]TLSExtension, len(p.Extensions))
 	uconn.Extensions = make([]TLSExtension, len(p.Extensions))
 	copy(uconn.Extensions, p.Extensions)
 	copy(uconn.Extensions, p.Extensions)
 
 
+	// Check whether NPN extension actually exists
+	var haveNPN bool
+
 	// reGrease, and point things to each other
 	// reGrease, and point things to each other
 	for _, e := range uconn.Extensions {
 	for _, e := range uconn.Extensions {
 		switch ext := e.(type) {
 		switch ext := e.(type) {
@@ -681,8 +684,15 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
 					ext.Versions[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_version)
 					ext.Versions[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_version)
 				}
 				}
 			}
 			}
+		case *NPNExtension:
+			haveNPN = true
 		}
 		}
 	}
 	}
+
+	// The default golang behavior in makeClientHello always sets NextProtoNeg if NextProtos is set,
+	// but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension
+	hello.NextProtoNeg = haveNPN
+
 	return nil
 	return nil
 }
 }
 
 

+ 59 - 0
vendor/github.com/refraction-networking/utls/u_public.go

@@ -421,6 +421,16 @@ func (chm *clientHelloMsg) getPublicPtr() *ClientHelloMsg {
 	}
 	}
 }
 }
 
 
+// UnmarshalClientHello allows external code to parse raw client hellos.
+// It returns nil on failure.
+func UnmarshalClientHello(data []byte) *ClientHelloMsg {
+	m := &clientHelloMsg{}
+	if m.unmarshal(data) {
+		return m.getPublicPtr()
+	}
+	return nil
+}
+
 // A CipherSuite is a specific combination of key agreement, cipher and MAC
 // A CipherSuite is a specific combination of key agreement, cipher and MAC
 // function. All cipher suites currently assume RSA key agreement.
 // function. All cipher suites currently assume RSA key agreement.
 type CipherSuite struct {
 type CipherSuite struct {
@@ -612,3 +622,52 @@ func (css *ClientSessionState) SetServerCertificates(ServerCertificates []*x509.
 func (css *ClientSessionState) SetVerifiedChains(VerifiedChains [][]*x509.Certificate) {
 func (css *ClientSessionState) SetVerifiedChains(VerifiedChains [][]*x509.Certificate) {
 	css.verifiedChains = VerifiedChains
 	css.verifiedChains = VerifiedChains
 }
 }
+
+// TicketKey is the internal representation of a session ticket key.
+type TicketKey struct {
+	// KeyName is an opaque byte string that serves to identify the session
+	// ticket key. It's exposed as plaintext in every session ticket.
+	KeyName [ticketKeyNameLen]byte
+	AesKey  [16]byte
+	HmacKey [16]byte
+}
+
+type TicketKeys []TicketKey
+type ticketKeys []ticketKey
+
+func TicketKeyFromBytes(b [32]byte) TicketKey {
+	tk := ticketKeyFromBytes(b)
+	return tk.ToPublic()
+}
+
+func (tk ticketKey) ToPublic() TicketKey {
+	return TicketKey{
+		KeyName: tk.keyName,
+		AesKey:  tk.aesKey,
+		HmacKey: tk.hmacKey,
+	}
+}
+
+func (TK TicketKey) ToPrivate() ticketKey {
+	return ticketKey{
+		keyName: TK.KeyName,
+		aesKey:  TK.AesKey,
+		hmacKey: TK.HmacKey,
+	}
+}
+
+func (tks ticketKeys) ToPublic() []TicketKey {
+	var TKS []TicketKey
+	for _, ks := range tks {
+		TKS = append(TKS, ks.ToPublic())
+	}
+	return TKS
+}
+
+func (TKS TicketKeys) ToPrivate() []ticketKey {
+	var tks []ticketKey
+	for _, TK := range TKS {
+		tks = append(tks, TK.ToPrivate())
+	}
+	return tks
+}

+ 6 - 6
vendor/vendor.json

@@ -457,10 +457,10 @@
 			"revisionTime": "2020-07-28T10:15:04Z"
 			"revisionTime": "2020-07-28T10:15:04Z"
 		},
 		},
 		{
 		{
-			"checksumSHA1": "CqKUytzURTAPk1mUOsySyXGQft0=",
+			"checksumSHA1": "zpjq6xHytHc3aAYfkaomcDmeAek=",
 			"path": "github.com/miekg/dns",
 			"path": "github.com/miekg/dns",
-			"revision": "d128d10d176b810543f8fecd089082a29d3f159d",
-			"revisionTime": "2020-04-28T07:24:18Z"
+			"revision": "ab67aa64230094bdd0167ee5360e00e0a250a3ac",
+			"revisionTime": "2021-08-04T16:16:52Z"
 		},
 		},
 		{
 		{
 			"checksumSHA1": "m2L8ohfZiFRsMW3iynaH/TWgnSY=",
 			"checksumSHA1": "m2L8ohfZiFRsMW3iynaH/TWgnSY=",
@@ -608,10 +608,10 @@
 			"revisionTime": "2021-06-04T20:39:09Z"
 			"revisionTime": "2021-06-04T20:39:09Z"
 		},
 		},
 		{
 		{
-			"checksumSHA1": "OagdWaWcbCBQZR5bBGgGaK3nddE=",
+			"checksumSHA1": "jOxlnqvKSKn1SIkA5ldRe5lxqAc=",
 			"path": "github.com/refraction-networking/utls",
 			"path": "github.com/refraction-networking/utls",
-			"revision": "186025ac7b77465439618d1aeb2a5e444714d1cc",
-			"revisionTime": "2020-07-29T01:25:36Z"
+			"revision": "0b2885c8c0d4467cfe98136748a9d011d0b8fff0",
+			"revisionTime": "2021-07-13T16:56:36Z"
 		},
 		},
 		{
 		{
 			"checksumSHA1": "Fn9JW8u40ABN9Uc9wuvquuyOB+8=",
 			"checksumSHA1": "Fn9JW8u40ABN9Uc9wuvquuyOB+8=",