Ver código fonte

Merge branch 'master' into quic-enhancements

Rod Hynes 4 anos atrás
pai
commit
f5d39d5185
89 arquivos alterados com 4081 adições e 1012 exclusões
  1. 134 0
      .github/workflows/tests.yml
  2. 0 77
      .travis.yml
  3. 1 1
      MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml
  4. 2 1
      README.md
  5. 9 0
      psiphon/common/osl/osl.go
  6. 49 0
      psiphon/common/parameters/parameters.go
  7. 5 0
      psiphon/common/parameters/parameters_test.go
  8. 41 0
      psiphon/common/parameters/portlist.go
  9. 196 0
      psiphon/common/portlist.go
  10. 222 0
      psiphon/common/portlist_test.go
  11. 82 12
      psiphon/common/protocol/serverEntry.go
  12. 18 4
      psiphon/config.go
  13. 34 25
      psiphon/controller.go
  14. BIN
      psiphon/controller_test.config.enc
  15. 6 1
      psiphon/dataStore.go
  16. 29 40
      psiphon/dialParameters.go
  17. 116 11
      psiphon/dialParameters_test.go
  18. BIN
      psiphon/feedback_test.config.enc
  19. 6 7
      psiphon/feedback_test.go
  20. 12 0
      psiphon/meekConn.go
  21. 3 0
      psiphon/net.go
  22. 16 8
      psiphon/notice.go
  23. 6 30
      psiphon/server/listener.go
  24. 2 36
      psiphon/server/listener_test.go
  25. 66 24
      psiphon/server/meek.go
  26. 75 7
      psiphon/server/meek_test.go
  27. 1 1
      psiphon/server/psinet/psinet.go
  28. 76 6
      psiphon/server/server_test.go
  29. 18 93
      psiphon/server/trafficRules.go
  30. 40 18
      psiphon/server/tunnelServer.go
  31. 105 54
      psiphon/server/udp.go
  32. 5 0
      psiphon/upstreamproxy/go-ntlm/README.md
  33. 23 5
      psiphon/upstreamproxy/go-ntlm/ntlm/av_pairs.go
  34. 37 5
      psiphon/upstreamproxy/go-ntlm/ntlm/challenge_responses.go
  35. 38 4
      psiphon/upstreamproxy/go-ntlm/ntlm/message_authenticate.go
  36. 23 1
      psiphon/upstreamproxy/go-ntlm/ntlm/message_challenge.go
  37. 6 1
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1.go
  38. 5 5
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1_test.go
  39. 6 2
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2.go
  40. 1 1
      psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2_test.go
  41. 14 0
      psiphon/upstreamproxy/go-ntlm/ntlm/payload.go
  42. 7 0
      psiphon/upstreamproxy/go-ntlm/ntlm/version.go
  43. 7 1
      vendor/github.com/Psiphon-Labs/tls-tris/key_agreement.go
  44. 1 1
      vendor/github.com/miekg/dns/Makefile.release
  45. 14 6
      vendor/github.com/miekg/dns/README.md
  46. 1 0
      vendor/github.com/miekg/dns/acceptfunc.go
  47. 110 45
      vendor/github.com/miekg/dns/client.go
  48. 1 4
      vendor/github.com/miekg/dns/defaults.go
  49. 27 3
      vendor/github.com/miekg/dns/dns.go
  50. 18 47
      vendor/github.com/miekg/dns/dnssec.go
  51. 3 4
      vendor/github.com/miekg/dns/dnssec_keygen.go
  52. 4 17
      vendor/github.com/miekg/dns/dnssec_keyscan.go
  53. 3 20
      vendor/github.com/miekg/dns/dnssec_privkey.go
  54. 29 5
      vendor/github.com/miekg/dns/doc.go
  55. 12 4
      vendor/github.com/miekg/dns/duplicate_generate.go
  56. 151 5
      vendor/github.com/miekg/dns/edns.go
  57. 12 12
      vendor/github.com/miekg/dns/generate.go
  58. 4 6
      vendor/github.com/miekg/dns/go.mod
  59. 9 38
      vendor/github.com/miekg/dns/go.sum
  60. 1 1
      vendor/github.com/miekg/dns/labels.go
  61. 0 0
      vendor/github.com/miekg/dns/listen_no_reuseport.go
  62. 0 0
      vendor/github.com/miekg/dns/listen_reuseport.go
  63. 15 13
      vendor/github.com/miekg/dns/msg.go
  64. 7 0
      vendor/github.com/miekg/dns/msg_generate.go
  65. 73 82
      vendor/github.com/miekg/dns/msg_helpers.go
  66. 11 5
      vendor/github.com/miekg/dns/msg_truncate.go
  67. 2 2
      vendor/github.com/miekg/dns/privaterr.go
  68. 59 22
      vendor/github.com/miekg/dns/scan.go
  69. 53 18
      vendor/github.com/miekg/dns/scan_rr.go
  70. 2 2
      vendor/github.com/miekg/dns/serve_mux.go
  71. 92 28
      vendor/github.com/miekg/dns/server.go
  72. 3 15
      vendor/github.com/miekg/dns/sig0.go
  73. 755 0
      vendor/github.com/miekg/dns/svcb.go
  74. 99 59
      vendor/github.com/miekg/dns/tsig.go
  75. 79 50
      vendor/github.com/miekg/dns/types.go
  76. 21 5
      vendor/github.com/miekg/dns/types_generate.go
  77. 3 1
      vendor/github.com/miekg/dns/update.go
  78. 1 1
      vendor/github.com/miekg/dns/version.go
  79. 183 0
      vendor/github.com/miekg/dns/zduplicate.go
  80. 134 0
      vendor/github.com/miekg/dns/zmsg.go
  81. 56 2
      vendor/github.com/miekg/dns/ztypes.go
  82. 20 0
      vendor/github.com/refraction-networking/utls/README.md
  83. 5 1
      vendor/github.com/refraction-networking/utls/key_agreement.go
  84. 16 1
      vendor/github.com/refraction-networking/utls/ticket.go
  85. 15 0
      vendor/github.com/refraction-networking/utls/u_common.go
  86. 360 0
      vendor/github.com/refraction-networking/utls/u_fingerprinter.go
  87. 10 0
      vendor/github.com/refraction-networking/utls/u_parrots.go
  88. 59 0
      vendor/github.com/refraction-networking/utls/u_public.go
  89. 6 6
      vendor/vendor.json

+ 134 - 0
.github/workflows/tests.yml

@@ -0,0 +1,134 @@
+name: CI
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - master
+      - staging-client
+      - staging-server
+
+jobs:
+  run_tests:
+
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ "ubuntu" ]
+        go: [ "1.14.12" ]
+        test-type: [ "detector", "coverage", "memory" ]
+
+    runs-on: ${{ matrix.os }}-latest
+
+    name: psiphon-tunnel-core ${{ matrix.test-type }} tests on ${{ matrix.os}}, Go ${{ matrix.go }}
+
+    permissions:
+      checks: write
+      contents: read
+
+    env:
+      GOPATH: ${{ github.workspace }}/go
+
+    steps:
+
+      - name: Clone repository
+        uses: actions/checkout@v2
+        with:
+          path: ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+
+      - name: Install Go
+        uses: actions/setup-go@v2
+        with:
+          go-version: ${{ matrix.go }}
+
+      - name: Install networking components
+        run: |
+          sudo apt-get update
+          sudo apt-get install libnetfilter-queue-dev
+          sudo apt-get install conntrack
+
+      - name: Install coverage tools
+        if: ${{ matrix.test-type == 'coverage' }}
+        run: |
+          go get github.com/axw/gocov/gocov
+          go get github.com/modocache/gover
+          go get github.com/mattn/goveralls
+          go get golang.org/x/tools/cmd/cover
+
+      - name: Check environment
+        run: |
+          echo "GitHub workspace: $GITHUB_WORKSPACE"
+          echo "Working directory: `pwd`"
+          echo "GOROOT: $GOROOT"
+          echo "GOPATH: $GOPATH"
+          echo "Go version: `go version`"
+
+      - name: Pave config files
+        env:
+          CONTROLLER_TEST_CONFIG: ${{ secrets.CONTROLLER_TEST_CONFIG }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          echo "$CONTROLLER_TEST_CONFIG" > ./psiphon/controller_test.config
+
+      # TODO: fix and re-enable test
+      # sudo -E env "PATH=$PATH" go test -v -race ./psiphon/common/tun
+      - name: Run tests with data race detector
+        if: ${{ matrix.test-type == 'detector' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v -race ./psiphon/common
+          go test -v -race ./psiphon/common/accesscontrol
+          go test -v -race ./psiphon/common/crypto/ssh
+          go test -v -race ./psiphon/common/fragmentor
+          go test -v -race ./psiphon/common/obfuscator
+          go test -v -race ./psiphon/common/osl
+          sudo -E env "PATH=$PATH" go test -v -race -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/common/packetman
+          go test -v -race ./psiphon/common/parameters
+          go test -v -race ./psiphon/common/protocol
+          go test -v -race ./psiphon/common/quic
+          go test -v -race ./psiphon/common/tactics
+          go test -v -race ./psiphon/common/values
+          go test -v -race ./psiphon/common/wildcard
+          go test -v -race ./psiphon/transferstats
+          sudo -E env "PATH=$PATH" go test -v -race -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/server
+          go test -v -race ./psiphon/server/psinet
+          go test -v -race ./psiphon
+          go test -v -race ./ClientLibrary/clientlib
+          go test -v -race ./Server/logging/analysis
+
+      # TODO: fix and re-enable test
+      # sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./psiphon/common/tun
+      - name: Run tests with coverage
+        env:
+          COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        if: ${{ matrix.test-type == 'coverage' && github.repository == 'Psiphon-Labs/psiphon-tunnel-core' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v -covermode=count -coverprofile=common.coverprofile ./psiphon/common
+          go test -v -covermode=count -coverprofile=accesscontrol.coverprofile ./psiphon/common/accesscontrol
+          go test -v -covermode=count -coverprofile=ssh.coverprofile ./psiphon/common/crypto/ssh
+          go test -v -covermode=count -coverprofile=fragmentor.coverprofile ./psiphon/common/fragmentor
+          go test -v -covermode=count -coverprofile=obfuscator.coverprofile ./psiphon/common/obfuscator
+          go test -v -covermode=count -coverprofile=osl.coverprofile ./psiphon/common/osl
+          sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=packetman.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/common/packetman
+          go test -v -covermode=count -coverprofile=parameters.coverprofile ./psiphon/common/parameters
+          go test -v -covermode=count -coverprofile=protocol.coverprofile ./psiphon/common/protocol
+          go test -v -covermode=count -coverprofile=quic.coverprofile ./psiphon/common/quic
+          go test -v -covermode=count -coverprofile=tactics.coverprofile ./psiphon/common/tactics
+          go test -v -covermode=count -coverprofile=values.coverprofile ./psiphon/common/values
+          go test -v -covermode=count -coverprofile=wildcard.coverprofile ./psiphon/common/wildcard
+          go test -v -covermode=count -coverprofile=transferstats.coverprofile ./psiphon/transferstats
+          sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=server.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./psiphon/server
+          go test -v -covermode=count -coverprofile=psinet.coverprofile ./psiphon/server/psinet
+          go test -v -covermode=count -coverprofile=psiphon.coverprofile ./psiphon
+          go test -v -covermode=count -coverprofile=clientlib.coverprofile ./ClientLibrary/clientlib
+          go test -v -covermode=count -coverprofile=analysis.coverprofile ./Server/logging/analysis
+          $GOPATH/bin/gover
+          $GOPATH/bin/goveralls -coverprofile=gover.coverprofile -service=github -repotoken "$COVERALLS_TOKEN"
+
+      - name: Run memory tests
+        if: ${{ matrix.test-type == 'memory' }}
+        run: |
+          cd ${{ github.workspace }}/go/src/github.com/Psiphon-Labs/psiphon-tunnel-core
+          go test -v ./psiphon/memory_test -run TestReconnectTunnel
+          go test -v ./psiphon/memory_test -run TestRestartController

+ 0 - 77
.travis.yml

@@ -1,77 +0,0 @@
-dist: trusty
-language: go
-sudo: required
-go:
-- 1.14.12
-addons:
-  apt_packages:
-    - libx11-dev
-    - libgles2-mesa-dev
-script:
-- cd psiphon
-- go test -race -v ./common
-- go test -race -v ./common/accesscontrol
-- go test -race -v ./common/crypto/ssh
-- go test -race -v ./common/fragmentor
-- go test -race -v ./common/obfuscator
-- go test -race -v ./common/osl
-- sudo -E env "PATH=$PATH" go test -race -v -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./common/packetman
-- go test -race -v ./common/parameters
-- go test -race -v ./common/protocol
-- go test -race -v ./common/quic
-- go test -race -v ./common/tactics
-# TODO: fix and reenable test, which is failing in TravisCI environment:
-# --- FAIL: TestTunneledTCPIPv4
-#    tun_test.go:226: startTestTCPClient failed: syscall.Connect failed: connection timed out
-#
-#- sudo -E env "PATH=$PATH" go test -race -v ./common/tun
-- go test -race -v ./common/values
-- go test -race -v ./common/wildcard
-- go test -race -v ./transferstats
-- sudo -E env "PATH=$PATH" go test -race -v -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./server
-- go test -race -v ./server/psinet
-- go test -race -v ../Server/logging/analysis
-- go test -race -v ../ClientLibrary/clientlib
-- go test -race -v
-- go test -v -covermode=count -coverprofile=common.coverprofile ./common
-- go test -v -covermode=count -coverprofile=accesscontrol.coverprofile ./common/accesscontrol
-- go test -v -covermode=count -coverprofile=ssh.coverprofile ./common/crypto/ssh
-- go test -v -covermode=count -coverprofile=fragmentor.coverprofile ./common/fragmentor
-- go test -v -covermode=count -coverprofile=obfuscator.coverprofile ./common/obfuscator
-- go test -v -covermode=count -coverprofile=osl.coverprofile ./common/osl
-- go test -v -covermode=count -coverprofile=parameters.coverprofile ./common/parameters
-- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=packetman.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./common/packetman
-- go test -v -covermode=count -coverprofile=protocol.coverprofile ./common/protocol
-- go test -v -covermode=count -coverprofile=quic.coverprofile ./common/quic
-- go test -v -covermode=count -coverprofile=tactics.coverprofile ./common/tactics
-# TODO: see "tun" test comment above
-#- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./common/tun
-- go test -v -covermode=count -coverprofile=values.coverprofile ./common/values
-- go test -v -covermode=count -coverprofile=wildcard.coverprofile ./common/wildcard
-- go test -v -covermode=count -coverprofile=transferstats.coverprofile ./transferstats
-- sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=server.coverprofile -tags "PSIPHON_RUN_PACKET_MANIPULATOR_TEST" ./server
-- go test -v -covermode=count -coverprofile=psinet.coverprofile ./server/psinet
-- go test -v -covermode=count -coverprofile=analysis.coverprofile ../Server/logging/analysis
-- go test -v -covermode=count -coverprofile=clientlib.coverprofile ../ClientLibrary/clientlib
-- go test -v -covermode=count -coverprofile=psiphon.coverprofile
-- go test -v ./memory_test -run TestReconnectTunnel
-- go test -v ./memory_test -run TestRestartController
-after_script:
-- $HOME/gopath/bin/gover
-- $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci -repotoken $COVERALLS_TOKEN
-before_install:
-- go get github.com/axw/gocov/gocov
-- go get github.com/modocache/gover
-- go get github.com/mattn/goveralls
-- if ! go get github.com/golang/tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
-- git rev-parse --short HEAD > psiphon/git_rev
-- openssl aes-256-cbc -K $encrypted_bf83b4ab4874_key -iv $encrypted_bf83b4ab4874_iv
-  -in psiphon/controller_test.config.enc -out psiphon/controller_test.config -d
-- openssl aes-256-cbc -K $encrypted_560fd0d04977_key -iv $encrypted_560fd0d04977_iv 
-  -in psiphon/feedback_test.config.enc -out psiphon/feedback_test.config -d
-notifications:
-  slack:
-    rooms:
-      secure: jVo/BZ1iFtg4g5V+eNxETwXPnbhwVwGzN1vkHJnCLAhV/md3/uHGsZQIMfitqgrX/T+9JBVRbRezjBwfJHYLs40IJTCWt167Lz8R1NlazLyEpcGcdesG05cTl9oEcBb7X52kZt7r8ZIBwdB7W6U/E0/i41qKamiEJqISMsdOoFA=
-    on_success: always
-    on_failure: always

+ 1 - 1
MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml

@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="utf-8"?>
 <full-backup-content>
-    <include domain="file" path="ca.psiphon.PsiphonTunnel.tunnel-core" />
+    <exclude domain="file" path="ca.psiphon.PsiphonTunnel.tunnel-core" />
 </full-backup-content>

+ 2 - 1
README.md

@@ -1,4 +1,5 @@
-[![Build Status](https://travis-ci.org/Psiphon-Labs/psiphon-tunnel-core.png)](https://travis-ci.org/Psiphon-Labs/psiphon-tunnel-core) [![Coverage Status](https://coveralls.io/repos/github/Psiphon-Labs/psiphon-tunnel-core/badge.svg?branch=master)](https://coveralls.io/github/Psiphon-Labs/psiphon-tunnel-core?branch=master)
+[![CI](https://github.com/Psiphon-Labs/psiphon-tunnel-core/actions/workflows/tests.yml/badge.svg)](https://github.com/Psiphon-Labs/psiphon-tunnel-core/actions/workflows/tests.yml) [![Coverage Status](https://coveralls.io/repos/github/Psiphon-Labs/psiphon-tunnel-core/badge.svg?branch=master)](https://coveralls.io/github/Psiphon-Labs/psiphon-tunnel-core?branch=master)
+
 
 Psiphon Tunnel Core README
 ================================================================================

+ 9 - 0
psiphon/common/osl/osl.go

@@ -294,6 +294,10 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 
 	for _, scheme := range config.Schemes {
 
+		if scheme == nil {
+			return nil, errors.TraceNew("invalid scheme")
+		}
+
 		epoch, err := time.Parse(time.RFC3339, scheme.Epoch)
 		if err != nil {
 			return nil, errors.Tracef("invalid epoch format: %s", err)
@@ -322,6 +326,11 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 
 		for index, seedSpec := range scheme.SeedSpecs {
+
+			if seedSpec == nil {
+				return nil, errors.TraceNew("invalid seed spec")
+			}
+
 			if len(seedSpec.ID) != KEY_LENGTH_BYTES {
 				return nil, errors.TraceNew("invalid seed spec ID")
 			}

+ 49 - 0
psiphon/common/parameters/parameters.go

@@ -99,6 +99,8 @@ const (
 	InitialLimitTunnelProtocolsCandidateCount        = "InitialLimitTunnelProtocolsCandidateCount"
 	LimitTunnelProtocolsProbability                  = "LimitTunnelProtocolsProbability"
 	LimitTunnelProtocols                             = "LimitTunnelProtocols"
+	LimitTunnelDialPortNumbersProbability            = "LimitTunnelDialPortNumbersProbability"
+	LimitTunnelDialPortNumbers                       = "LimitTunnelDialPortNumbers"
 	LimitTLSProfilesProbability                      = "LimitTLSProfilesProbability"
 	LimitTLSProfiles                                 = "LimitTLSProfiles"
 	UseOnlyCustomTLSProfiles                         = "UseOnlyCustomTLSProfiles"
@@ -362,6 +364,9 @@ var defaultParameters = map[string]struct {
 	LimitTunnelProtocolsProbability: {value: 1.0, minimum: 0.0},
 	LimitTunnelProtocols:            {value: protocol.TunnelProtocols{}},
 
+	LimitTunnelDialPortNumbersProbability: {value: 1.0, minimum: 0.0},
+	LimitTunnelDialPortNumbers:            {value: TunnelProtocolPortLists{}},
+
 	LimitTLSProfilesProbability:           {value: 1.0, minimum: 0.0},
 	LimitTLSProfiles:                      {value: protocol.TLSProfiles{}},
 	UseOnlyCustomTLSProfiles:              {value: false},
@@ -931,6 +936,22 @@ func (p *Parameters) Set(
 					}
 					return nil, errors.Trace(err)
 				}
+			case FrontingSpecs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
+			case TunnelProtocolPortLists:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 
 			// Enforce any minimums. Assumes defaultParameters[name]
@@ -1389,3 +1410,31 @@ func (p ParametersAccessor) FrontingSpecs(name string) FrontingSpecs {
 	p.snapshot.getValue(name, &value)
 	return value
 }
+
+// TunnelProtocolPortLists returns a TunnelProtocolPortLists parameter value.
+func (p ParametersAccessor) TunnelProtocolPortLists(name string) TunnelProtocolPortLists {
+
+	probabilityName := name + "Probability"
+	_, ok := p.snapshot.parameters[probabilityName]
+	if ok {
+		probabilityValue := float64(1.0)
+		p.snapshot.getValue(probabilityName, &probabilityValue)
+		if !prng.FlipWeightedCoin(probabilityValue) {
+			defaultParameter, ok := defaultParameters[name]
+			if ok {
+				defaultValue, ok := defaultParameter.value.(TunnelProtocolPortLists)
+				if ok {
+					value := make(TunnelProtocolPortLists)
+					for tunnelProtocol, portLists := range defaultValue {
+						value[tunnelProtocol] = portLists
+					}
+					return value
+				}
+			}
+		}
+	}
+
+	value := make(TunnelProtocolPortLists)
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 5 - 0
psiphon/common/parameters/parameters_test.go

@@ -154,6 +154,11 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v)
 			}
+		case TunnelProtocolPortLists:
+			g := p.Get().TunnelProtocolPortLists(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v)
+			}
 		default:
 			t.Fatalf("Unhandled default type: %s", name)
 		}

+ 41 - 0
psiphon/common/parameters/portlist.go

@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// Validate checks that tunnel protocol names are valid.
+func (lists TunnelProtocolPortLists) Validate() error {
+	for tunnelProtocol, _ := range lists {
+		if tunnelProtocol != protocol.TUNNEL_PROTOCOLS_ALL &&
+			!common.Contains(protocol.SupportedTunnelProtocols, tunnelProtocol) {
+			return errors.TraceNew("invalid tunnel protocol for port list")
+		}
+	}
+	return nil
+}

+ 196 - 0
psiphon/common/portlist.go

@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"bytes"
+	"encoding/json"
+	"strconv"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// PortList provides a lookup for a configured list of IP ports and port
+// ranges. PortList is intended for use with JSON config files and is
+// initialized via UnmarshalJSON.
+//
+// A JSON port list field should look like:
+//
+// "FieldName": [1, 2, 3, [10, 20], [30, 40]]
+//
+// where the ports in the list are 1, 2, 3, 10-20, 30-40. UnmarshalJSON
+// validates that each port is in the range 1-65535 and that ranges have two
+// elements in increasing order. PortList is designed to be backwards
+// compatible with existing JSON config files where port list fields were
+// defined as `[]int`.
+type PortList struct {
+	portRanges [][2]int
+	lookup     map[int]bool
+}
+
+const lookupThreshold = 10
+
+// OptimizeLookups converts the internal port list representation to use a
+// map, which increases the performance of lookups for longer lists with an
+// increased memory footprint tradeoff. OptimizeLookups is not safe to use
+// concurrently with Lookup and should be called immediately after
+// UnmarshalJSON and before performing lookups.
+func (p *PortList) OptimizeLookups() {
+	if p == nil {
+		return
+	}
+	// TODO: does the threshold take long ranges into account?
+	if len(p.portRanges) > lookupThreshold {
+		p.lookup = make(map[int]bool)
+		for _, portRange := range p.portRanges {
+			for i := portRange[0]; i <= portRange[1]; i++ {
+				p.lookup[i] = true
+			}
+		}
+	}
+}
+
+// IsEmpty returns true for a nil PortList or a PortList with no entries.
+func (p *PortList) IsEmpty() bool {
+	if p == nil {
+		return true
+	}
+	return len(p.portRanges) == 0
+}
+
+// Lookup returns true if the specified port is in the port list and false
+// otherwise. Lookups on a nil PortList are allowed and return false.
+func (p *PortList) Lookup(port int) bool {
+	if p == nil {
+		return false
+	}
+	if p.lookup != nil {
+		return p.lookup[port]
+	}
+	for _, portRange := range p.portRanges {
+		if port >= portRange[0] && port <= portRange[1] {
+			return true
+		}
+	}
+	return false
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface.
+func (p *PortList) UnmarshalJSON(b []byte) error {
+
+	p.portRanges = nil
+	p.lookup = nil
+
+	if bytes.Equal(b, []byte("null")) {
+		return nil
+	}
+
+	decoder := json.NewDecoder(bytes.NewReader(b))
+	decoder.UseNumber()
+
+	var array []interface{}
+
+	err := decoder.Decode(&array)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	p.portRanges = make([][2]int, len(array))
+
+	for i, portRange := range array {
+
+		var startPort, endPort int64
+
+		if portNumber, ok := portRange.(json.Number); ok {
+
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+
+			startPort = port
+			endPort = port
+
+		} else if array, ok := portRange.([]interface{}); ok {
+
+			if len(array) != 2 {
+				return errors.TraceNew("invalid range size")
+			}
+
+			portNumber, ok := array[0].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			startPort = port
+
+			portNumber, ok = array[1].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err = portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			endPort = port
+
+		} else {
+
+			return errors.TraceNew("invalid type")
+		}
+
+		if startPort < 1 || startPort > 65535 {
+			return errors.TraceNew("invalid range start")
+		}
+
+		if endPort < 1 || endPort > 65535 || endPort < startPort {
+			return errors.TraceNew("invalid range end")
+		}
+
+		p.portRanges[i] = [2]int{int(startPort), int(endPort)}
+	}
+
+	return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface.
+func (p *PortList) MarshalJSON() ([]byte, error) {
+	var json bytes.Buffer
+	json.WriteString("[")
+	for i, portRange := range p.portRanges {
+		if i > 0 {
+			json.WriteString(",")
+		}
+		if portRange[0] == portRange[1] {
+			json.WriteString(strconv.Itoa(portRange[0]))
+		} else {
+			json.WriteString("[")
+			json.WriteString(strconv.Itoa(portRange[0]))
+			json.WriteString(",")
+			json.WriteString(strconv.Itoa(portRange[1]))
+			json.WriteString("]")
+		}
+	}
+	json.WriteString("]")
+	return json.Bytes(), nil
+}

+ 222 - 0
psiphon/common/portlist_test.go

@@ -0,0 +1,222 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+	"unicode"
+)
+
+func TestPortList(t *testing.T) {
+
+	var p *PortList
+
+	err := json.Unmarshal([]byte("[1.5]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of float port number")
+	}
+
+	err = json.Unmarshal([]byte("[-1]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of negative port number")
+	}
+
+	err = json.Unmarshal([]byte("[0]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[65536]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[[2,1]]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port range")
+	}
+
+	p = nil
+
+	if p.Lookup(1) != false {
+		t.Fatalf("unexpected nil PortList Lookup result")
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected nil PortList IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[1]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	s := struct {
+		List1 *PortList
+		List2 *PortList
+	}{}
+
+	jsonString := `
+    {
+        "List1" : [1,2,[10,20],100,[1000,2000]],
+        "List2" : [3,4,5,[300,400],1000,2000,[3000,3996],3997,3998,3999,4000]
+    }
+    `
+
+	err = json.Unmarshal([]byte(jsonString), &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	// Marshal and re-Unmarshal to exercise PortList.MarshalJSON.
+
+	jsonBytes, err := json.Marshal(s)
+	if err != nil {
+		t.Fatalf("Marshal failed: %v", err)
+	}
+
+	strip := func(s string) string {
+		return strings.Map(func(r rune) rune {
+			if unicode.IsSpace(r) {
+				return -1
+			}
+			return r
+		}, s)
+	}
+
+	if strip(jsonString) != strip(string(jsonBytes)) {
+
+		t.Fatalf("unexpected JSON encoding")
+	}
+
+	err = json.Unmarshal(jsonBytes, &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	s.List1.OptimizeLookups()
+	if s.List1.lookup != nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	s.List2.OptimizeLookups()
+	if s.List2.lookup == nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	for port := 0; port < 65536; port++ {
+
+		lookup1 := s.List1.Lookup(port)
+		expected1 := port == 1 ||
+			port == 2 ||
+			(port >= 10 && port <= 20) ||
+			port == 100 ||
+			(port >= 1000 && port <= 2000)
+		if lookup1 != expected1 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup1)
+		}
+
+		lookup2 := s.List2.Lookup(port)
+		expected2 := port == 3 ||
+			port == 4 ||
+			port == 5 ||
+			(port >= 300 && port <= 400) ||
+			port == 1000 || port == 2000 ||
+			(port >= 3000 && port <= 4000)
+		if lookup2 != expected2 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup2)
+		}
+	}
+}
+
+func BenchmarkPortListLinear(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,[10,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup != nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}
+
+func BenchmarkPortListMap(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,10,[11,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup == nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}

+ 82 - 12
psiphon/common/protocol/serverEntry.go

@@ -490,51 +490,121 @@ type ConditionallyEnabledComponents interface {
 	RefractionNetworkingEnabled() bool
 }
 
-// GetSupportedProtocols returns a list of tunnel protocols supported
-// by the ServerEntry's capabilities.
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// GetSupportedProtocols returns a list of tunnel protocols supported by the
+// ServerEntry's capabilities and allowed by various constraints.
 func (serverEntry *ServerEntry) GetSupportedProtocols(
 	conditionallyEnabled ConditionallyEnabledComponents,
 	useUpstreamProxy bool,
 	limitTunnelProtocols []string,
+	limitTunnelDialPortNumbers TunnelProtocolPortLists,
 	excludeIntensive bool) []string {
 
 	supportedProtocols := make([]string, 0)
 
-	for _, protocol := range SupportedTunnelProtocols {
+	for _, tunnelProtocol := range SupportedTunnelProtocols {
 
-		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(protocol) {
+		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(tunnelProtocol) {
 			continue
 		}
 
 		if len(limitTunnelProtocols) > 0 {
-			if !common.Contains(limitTunnelProtocols, protocol) {
+			if !common.Contains(limitTunnelProtocols, tunnelProtocol) {
 				continue
 			}
 		} else {
-			if common.Contains(DefaultDisabledTunnelProtocols, protocol) {
+			if common.Contains(DefaultDisabledTunnelProtocols, tunnelProtocol) {
 				continue
 			}
 		}
 
-		if excludeIntensive && TunnelProtocolIsResourceIntensive(protocol) {
+		if excludeIntensive && TunnelProtocolIsResourceIntensive(tunnelProtocol) {
 			continue
 		}
 
-		if (TunnelProtocolUsesQUIC(protocol) && !conditionallyEnabled.QUICEnabled()) ||
-			(TunnelProtocolUsesMarionette(protocol) && !conditionallyEnabled.MarionetteEnabled()) ||
-			(TunnelProtocolUsesRefractionNetworking(protocol) &&
+		if (TunnelProtocolUsesQUIC(tunnelProtocol) && !conditionallyEnabled.QUICEnabled()) ||
+			(TunnelProtocolUsesMarionette(tunnelProtocol) && !conditionallyEnabled.MarionetteEnabled()) ||
+			(TunnelProtocolUsesRefractionNetworking(tunnelProtocol) &&
 				!conditionallyEnabled.RefractionNetworkingEnabled()) {
 			continue
 		}
 
-		if serverEntry.SupportsProtocol(protocol) {
-			supportedProtocols = append(supportedProtocols, protocol)
+		if !serverEntry.SupportsProtocol(tunnelProtocol) {
+			continue
+		}
+
+		dialPortNumber, err := serverEntry.GetDialPortNumber(tunnelProtocol)
+		if err != nil {
+			continue
+		}
+
+		if len(limitTunnelDialPortNumbers) > 0 {
+			if portList, ok := limitTunnelDialPortNumbers[tunnelProtocol]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			} else if portList, ok := limitTunnelDialPortNumbers[TUNNEL_PROTOCOLS_ALL]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			}
 		}
 
+		supportedProtocols = append(supportedProtocols, tunnelProtocol)
+
 	}
 	return supportedProtocols
 }
 
+func (serverEntry *ServerEntry) GetDialPortNumber(tunnelProtocol string) (int, error) {
+
+	if !serverEntry.SupportsProtocol(tunnelProtocol) {
+		return 0, errors.TraceNew("protocol not supported")
+	}
+
+	switch tunnelProtocol {
+
+	case TUNNEL_PROTOCOL_SSH:
+		return serverEntry.SshPort, nil
+
+	case TUNNEL_PROTOCOL_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedPort, nil
+
+	case TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedTapDancePort, nil
+
+	case TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedConjurePort, nil
+
+	case TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedQUICPort, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK,
+		TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
+		return 443, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
+		return 80, nil
+
+	case TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK:
+		return serverEntry.MeekServerPort, nil
+
+	case TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+		// The port is encoded in the marionnete "format"
+		// Limitations:
+		// - not compatible with LimitDialPortNumbers
+		// - accurate port is not reported via dial_port_number
+		return -1, nil
+	}
+
+	return 0, errors.TraceNew("unknown protocol")
+}
+
 // GetSupportedTacticsProtocols returns a list of tunnel protocols,
 // supported by the ServerEntry's capabilities, that may be used
 // for tactics requests.

+ 18 - 4
psiphon/config.go

@@ -751,6 +751,9 @@ type Config struct {
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	UpstreamProxyAllowAllServerEntrySources *bool
 
+	// LimitTunnelDialPortNumbers is for testing purposes.
+	LimitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
 	//
@@ -1604,7 +1607,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UseOnlyCustomTLSProfiles] = *config.UseOnlyCustomTLSProfiles
 	}
 
-	if config.CustomTLSProfiles != nil {
+	if len(config.CustomTLSProfiles) > 0 {
 		applyParameters[parameters.CustomTLSProfiles] = config.CustomTLSProfiles
 	}
 
@@ -1616,7 +1619,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.NoDefaultTLSSessionIDProbability] = *config.NoDefaultTLSSessionIDProbability
 	}
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		applyParameters[parameters.DisableFrontingProviderTLSProfiles] = config.DisableFrontingProviderTLSProfiles
 	}
 
@@ -1660,7 +1663,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ConjureAPIRegistrarURL] = config.ConjureAPIRegistrarURL
 	}
 
-	if config.ConjureAPIRegistrarFrontingSpecs != nil {
+	if len(config.ConjureAPIRegistrarFrontingSpecs) > 0 {
 		applyParameters[parameters.ConjureAPIRegistrarFrontingSpecs] = config.ConjureAPIRegistrarFrontingSpecs
 	}
 
@@ -1720,6 +1723,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 	}
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		applyParameters[parameters.LimitTunnelDialPortNumbers] = config.LimitTunnelDialPortNumbers
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 
@@ -1929,7 +1936,7 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.NoDefaultTLSSessionIDProbability)
 	}
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		hash.Write([]byte("DisableFrontingProviderTLSProfiles"))
 		encodedDisableFrontingProviderTLSProfiles, _ :=
 			json.Marshal(config.DisableFrontingProviderTLSProfiles)
@@ -2044,6 +2051,13 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)
 	}
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		hash.Write([]byte("LimitTunnelDialPortNumbers"))
+		encodedLimitTunnelDialPortNumbers, _ :=
+			json.Marshal(config.LimitTunnelDialPortNumbers)
+		hash.Write(encodedLimitTunnelDialPortNumbers)
+	}
+
 	config.dialParametersHash = hash.Sum(nil)
 }
 

+ 34 - 25
psiphon/controller.go

@@ -1358,15 +1358,16 @@ func (controller *Controller) triggerFetches() {
 }
 
 type protocolSelectionConstraints struct {
-	useUpstreamProxy                    bool
-	initialLimitProtocols               protocol.TunnelProtocols
-	initialLimitProtocolsCandidateCount int
-	limitProtocols                      protocol.TunnelProtocols
-	replayCandidateCount                int
+	useUpstreamProxy                          bool
+	initialLimitTunnelProtocols               protocol.TunnelProtocols
+	initialLimitTunnelProtocolsCandidateCount int
+	limitTunnelProtocols                      protocol.TunnelProtocols
+	limitTunnelDialPortNumbers                protocol.TunnelProtocolPortLists
+	replayCandidateCount                      int
 }
 
 func (p *protocolSelectionConstraints) hasInitialProtocols() bool {
-	return len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > 0
+	return len(p.initialLimitTunnelProtocols) > 0 && p.initialLimitTunnelProtocolsCandidateCount > 0
 }
 
 func (p *protocolSelectionConstraints) isInitialCandidate(
@@ -1377,7 +1378,8 @@ func (p *protocolSelectionConstraints) isInitialCandidate(
 		len(serverEntry.GetSupportedProtocols(
 			conditionallyEnabledComponents{},
 			p.useUpstreamProxy,
-			p.initialLimitProtocols,
+			p.initialLimitTunnelProtocols,
+			p.limitTunnelDialPortNumbers,
 			excludeIntensive)) > 0
 }
 
@@ -1385,12 +1387,12 @@ func (p *protocolSelectionConstraints) isCandidate(
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) bool {
 
-	return len(p.limitProtocols) == 0 ||
-		len(serverEntry.GetSupportedProtocols(
-			conditionallyEnabledComponents{},
-			p.useUpstreamProxy,
-			p.limitProtocols,
-			excludeIntensive)) > 0
+	return len(serverEntry.GetSupportedProtocols(
+		conditionallyEnabledComponents{},
+		p.useUpstreamProxy,
+		p.limitTunnelProtocols,
+		p.limitTunnelDialPortNumbers,
+		excludeIntensive)) > 0
 }
 
 func (p *protocolSelectionConstraints) canReplay(
@@ -1413,16 +1415,19 @@ func (p *protocolSelectionConstraints) supportedProtocols(
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) []string {
 
-	limitProtocols := p.limitProtocols
+	limitTunnelProtocols := p.limitTunnelProtocols
+
+	if len(p.initialLimitTunnelProtocols) > 0 &&
+		p.initialLimitTunnelProtocolsCandidateCount > connectTunnelCount {
 
-	if len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > connectTunnelCount {
-		limitProtocols = p.initialLimitProtocols
+		limitTunnelProtocols = p.initialLimitTunnelProtocols
 	}
 
 	return serverEntry.GetSupportedProtocols(
 		conditionallyEnabledComponents{},
 		p.useUpstreamProxy,
-		limitProtocols,
+		limitTunnelProtocols,
+		p.limitTunnelDialPortNumbers,
 		excludeIntensive)
 }
 
@@ -1578,11 +1583,15 @@ func (controller *Controller) launchEstablishing() {
 	p := controller.config.GetParameters().Get()
 
 	controller.protocolSelectionConstraints = &protocolSelectionConstraints{
-		useUpstreamProxy:                    controller.config.UseUpstreamProxy(),
-		initialLimitProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
-		initialLimitProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
-		limitProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
-		replayCandidateCount:                p.Int(parameters.ReplayCandidateCount),
+		useUpstreamProxy:                          controller.config.UseUpstreamProxy(),
+		initialLimitTunnelProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
+		initialLimitTunnelProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
+		limitTunnelProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
+
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+
+		replayCandidateCount: p.Int(parameters.ReplayCandidateCount),
 	}
 
 	// ConnectionWorkerPoolSize may be set by tactics.
@@ -1626,7 +1635,7 @@ func (controller *Controller) launchEstablishing() {
 	// proceeding.
 
 	awaitResponse := tunnelPoolSize > 1 ||
-		controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0
+		controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0
 
 	// AvailableEgressRegions: after a fresh install, the outer client may not
 	// have a list of regions to display; and LimitTunnelProtocols may reduce the
@@ -1720,11 +1729,11 @@ func (controller *Controller) launchEstablishing() {
 		// protocols may have some bad effect, such as a firewall blocking all
 		// traffic from a host.
 
-		if controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0 {
+		if controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0 {
 
 			if reportResponse.initialCandidatesAnyEgressRegion == 0 {
 				NoticeWarning("skipping initial limit tunnel protocols")
-				controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount = 0
+				controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount = 0
 
 				// Since we were unable to satisfy the InitialLimitTunnelProtocols
 				// tactic, trigger RSL, OSL, and upgrade fetches to potentially

BIN
psiphon/controller_test.config.enc


+ 6 - 1
psiphon/dataStore.go

@@ -671,7 +671,11 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 			return false, nil, errors.TraceNew("TargetServerEntry does not support EgressRegion")
 		}
 
-		limitTunnelProtocols := config.GetParameters().Get().TunnelProtocols(parameters.LimitTunnelProtocols)
+		p := config.GetParameters().Get()
+		limitTunnelProtocols := p.TunnelProtocols(parameters.LimitTunnelProtocols)
+		limitTunnelDialPortNumbers := protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers))
+
 		if len(limitTunnelProtocols) > 0 {
 			// At the ServerEntryIterator level, only limitTunnelProtocols is applied;
 			// excludeIntensive is handled higher up.
@@ -679,6 +683,7 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 				conditionallyEnabledComponents{},
 				config.UseUpstreamProxy(),
 				limitTunnelProtocols,
+				limitTunnelDialPortNumbers,
 				false)) == 0 {
 				return false, nil, errors.Tracef(
 					"TargetServerEntry does not support LimitTunnelProtocols: %v", limitTunnelProtocols)

+ 29 - 40
psiphon/dialParameters.go

@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"net"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync/atomic"
 	"time"
@@ -716,40 +717,27 @@ func MakeDialParameters(
 	// Set dial address fields. This portion of configuration is
 	// deterministic, given the parameters established or replayed so far.
 
-	switch dialParams.TunnelProtocol {
+	dialPortNumber, err := serverEntry.GetDialPortNumber(dialParams.TunnelProtocol)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
-	case protocol.TUNNEL_PROTOCOL_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)
+	dialParams.DialPortNumber = strconv.Itoa(dialPortNumber)
 
-	case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
+	switch dialParams.TunnelProtocol {
 
-	case protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedTapDancePort)
+	case protocol.TUNNEL_PROTOCOL_SSH,
+		protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
 
-	case protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedConjurePort)
+		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 
-	case protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedQUICPort)
+	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK,
+		protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
 
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
-		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
-		if serverEntry.MeekFrontingDisableSNI {
-			dialParams.MeekSNIServerName = ""
-			// When SNI is omitted, the transformed host name is not used.
-			dialParams.MeekTransformedHostName = false
-		} else if !dialParams.MeekTransformedHostName {
-			dialParams.MeekSNIServerName = dialParams.MeekFrontingDialAddress
-		}
-
-	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
-		// Note: port comes from marionnete "format"
-		dialParams.DirectDialAddress = serverEntry.IpAddress
-
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		if serverEntry.MeekFrontingDisableSNI {
 			dialParams.MeekSNIServerName = ""
@@ -760,15 +748,17 @@ func MakeDialParameters(
 		}
 
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:80", dialParams.MeekFrontingDialAddress)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		// For FRONTED HTTP, the Host header cannot be transformed.
 		dialParams.MeekTransformedHostName = false
 
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
-			if serverEntry.MeekServerPort == 80 {
+			if dialPortNumber == 80 {
 				dialParams.MeekHostHeader = serverEntry.IpAddress
 			} else {
 				dialParams.MeekHostHeader = dialParams.MeekDialAddress
@@ -778,17 +768,22 @@ func MakeDialParameters(
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 			// Note: IP address in SNI field will be omitted.
 			dialParams.MeekSNIServerName = serverEntry.IpAddress
 		}
-		if serverEntry.MeekServerPort == 443 {
+		if dialPortNumber == 443 {
 			dialParams.MeekHostHeader = serverEntry.IpAddress
 		} else {
 			dialParams.MeekHostHeader = dialParams.MeekDialAddress
 		}
 
+	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+
+		// Note: port comes from marionette "format"
+		dialParams.DirectDialAddress = serverEntry.IpAddress
+
 	default:
 		return nil, errors.Tracef(
 			"unknown tunnel protocol: %s", dialParams.TunnelProtocol)
@@ -797,7 +792,7 @@ func MakeDialParameters(
 
 	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) {
 
-		host, port, _ := net.SplitHostPort(dialParams.MeekDialAddress)
+		host, _, _ := net.SplitHostPort(dialParams.MeekDialAddress)
 
 		if p.Bool(parameters.MeekDialDomainsOnly) {
 			if net.ParseIP(host) != nil {
@@ -806,17 +801,11 @@ func MakeDialParameters(
 			}
 		}
 
-		dialParams.DialPortNumber = port
-
 		// The underlying TLS will automatically disable SNI for IP address server name
 		// values; we have this explicit check here so we record the correct value for stats.
 		if net.ParseIP(dialParams.MeekSNIServerName) != nil {
 			dialParams.MeekSNIServerName = ""
 		}
-
-	} else {
-
-		_, dialParams.DialPortNumber, _ = net.SplitHostPort(dialParams.DirectDialAddress)
 	}
 
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.

+ 116 - 11
psiphon/dialParameters_test.go

@@ -87,7 +87,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
-	err = clientConfig.SetParameters("tag1", true, applyParameters)
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -346,7 +346,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	// Test: no replay after change tactics
 
 	applyParameters[parameters.ReplayDialParametersTTL] = "1s"
-	err = clientConfig.SetParameters("tag2", true, applyParameters)
+	err = clientConfig.SetParameters("tag2", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -400,7 +400,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.ReplayObfuscatedQUIC] = false
 	applyParameters[parameters.ReplayLivenessTest] = false
 	applyParameters[parameters.ReplayAPIRequestPadding] = false
-	err = clientConfig.SetParameters("tag3", true, applyParameters)
+	err = clientConfig.SetParameters("tag3", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -442,7 +442,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	applyParameters[parameters.RestrictFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 1.0
-	err = clientConfig.SetParameters("tag4", true, applyParameters)
+	err = clientConfig.SetParameters("tag4", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -462,7 +462,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 0.0
-	err = clientConfig.SetParameters("tag5", true, applyParameters)
+	err = clientConfig.SetParameters("tag5", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -558,6 +558,110 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 }
 
+func TestLimitTunnelDialPortNumbers(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-limit-tunnel-dial-port-numbers-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %s", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	SetNoticeWriter(ioutil.Discard)
+
+	clientConfig := &Config{
+		PropagationChannelId: "0",
+		SponsorId:            "0",
+		DataRootDirectory:    testDataDirName,
+		NetworkIDGetter:      new(testNetworkGetter),
+	}
+
+	err = clientConfig.Commit(false)
+	if err != nil {
+		t.Fatalf("error committing configuration file: %s", err)
+	}
+
+	jsonLimitDialPortNumbers := `
+    {
+        "SSH" : [[10,11]],
+        "OSSH" : [[20,21]],
+        "QUIC-OSSH" : [[30,31]],
+        "TAPDANCE-OSSH" : [[40,41]],
+        "CONJURE-OSSH" : [[50,51]],
+        "All" : [[60,61],80,443]
+    }
+    `
+
+	var limitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+	err = json.Unmarshal([]byte(jsonLimitDialPortNumbers), &limitTunnelDialPortNumbers)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	applyParameters := make(map[string]interface{})
+	applyParameters[parameters.LimitTunnelDialPortNumbers] = limitTunnelDialPortNumbers
+	applyParameters[parameters.LimitTunnelDialPortNumbersProbability] = 1.0
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	constraints := &protocolSelectionConstraints{
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			clientConfig.GetParameters().Get().TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+	}
+
+	selectProtocol := func(serverEntry *protocol.ServerEntry) (string, bool) {
+		return constraints.selectProtocol(0, false, serverEntry)
+	}
+
+	for _, tunnelProtocol := range protocol.SupportedTunnelProtocols {
+
+		if common.Contains(protocol.DefaultDisabledTunnelProtocols, tunnelProtocol) {
+			continue
+		}
+
+		serverEntries := makeMockServerEntries(tunnelProtocol, "", 100)
+
+		selected := false
+		skipped := false
+
+		for _, serverEntry := range serverEntries {
+
+			selectedProtocol, ok := selectProtocol(serverEntry)
+
+			if ok {
+
+				if selectedProtocol != tunnelProtocol {
+					t.Fatalf("unexpected selected protocol: %s", selectedProtocol)
+				}
+
+				port, err := serverEntry.GetDialPortNumber(selectedProtocol)
+				if err != nil {
+					t.Fatalf("GetDialPortNumber failed: %s", err)
+				}
+
+				if port%10 != 0 && port%10 != 1 && !protocol.TunnelProtocolUsesFrontedMeek(selectedProtocol) {
+					t.Fatalf("unexpected dial port number: %d", port)
+				}
+
+				selected = true
+
+			} else {
+
+				skipped = true
+			}
+		}
+
+		if !selected {
+			t.Fatalf("expected at least one selected server entry: %s", tunnelProtocol)
+		}
+
+		if !skipped && !protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+			t.Fatalf("expected at least one skipped server entry: %s", tunnelProtocol)
+		}
+	}
+}
+
 func makeMockServerEntries(
 	tunnelProtocol string,
 	frontingProviderID string,
@@ -568,17 +672,18 @@ func makeMockServerEntries(
 	for i := 0; i < count; i++ {
 		serverEntries[i] = &protocol.ServerEntry{
 			IpAddress:                  fmt.Sprintf("192.168.0.%d", i),
-			SshPort:                    1,
-			SshObfuscatedPort:          2,
-			SshObfuscatedQUICPort:      3,
-			SshObfuscatedTapDancePort:  4,
-			SshObfuscatedConjurePort:   5,
-			MeekServerPort:             6,
+			SshPort:                    prng.Range(10, 19),
+			SshObfuscatedPort:          prng.Range(20, 29),
+			SshObfuscatedQUICPort:      prng.Range(30, 39),
+			SshObfuscatedTapDancePort:  prng.Range(40, 49),
+			SshObfuscatedConjurePort:   prng.Range(50, 59),
+			MeekServerPort:             prng.Range(60, 69),
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
 			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
+			Capabilities:               []string{protocol.GetCapability(tunnelProtocol)},
 		}
 	}
 

BIN
psiphon/feedback_test.config.enc


+ 6 - 7
psiphon/feedback_test.go

@@ -23,6 +23,7 @@ import (
 	"context"
 	"encoding/json"
 	"io/ioutil"
+	"os/exec"
 	"testing"
 )
 
@@ -41,7 +42,7 @@ type Diagnostics struct {
 }
 
 func TestFeedbackUpload(t *testing.T) {
-	configFileContents, err := ioutil.ReadFile("feedback_test.config")
+	configFileContents, err := ioutil.ReadFile("controller_test.config")
 	if err != nil {
 		// Skip, don't fail, if config file is not present
 		t.Skipf("error loading configuration file: %s", err)
@@ -65,13 +66,11 @@ func TestFeedbackUpload(t *testing.T) {
 		t.Fatalf("error committing configuration file: %s", err)
 	}
 
-	// git_rev is a file which contains the shortened hash of the latest commit
-	// pointed to by HEAD, i.e. git rev-parse --short HEAD.
-
-	shortRevHash, err := ioutil.ReadFile("git_rev")
+	shortRevHash, err := exec.Command("git", "rev-parse", "--short", "HEAD").Output()
 	if err != nil {
-		// Skip, don't fail, if git rev file is not present
-		t.Skipf("error loading git revision file: %s", err)
+		// Log, don't fail, if git rev is not available
+		t.Logf("error loading git revision file: %s", err)
+		shortRevHash = []byte("unknown")
 	}
 
 	// Construct feedback data which can be verified later

+ 12 - 0
psiphon/meekConn.go

@@ -645,6 +645,14 @@ func DialMeek(
 		meek.meekObfuscatedKey = meekConfig.MeekObfuscatedKey
 		meek.meekObfuscatorPaddingSeed = meekConfig.MeekObfuscatorPaddingSeed
 		meek.clientTunnelProtocol = meekConfig.ClientTunnelProtocol
+
+	} else if meek.mode == MeekModePlaintextRoundTrip {
+
+		// MeekModeRelay and MeekModeObfuscatedRoundTrip set the Host header
+		// implicitly via meek.url; MeekModePlaintextRoundTrip does not use
+		// meek.url; it uses the RoundTrip input request.URL instead. So the
+		// Host header is set to meekConfig.HostHeader explicitly here.
+		meek.additionalHeaders.Add("Host", meekConfig.HostHeader)
 	}
 
 	return meek, nil
@@ -842,6 +850,10 @@ func (meek *MeekConn) RoundTrip(request *http.Request) (*http.Response, error) {
 
 	requestCtx := request.Context()
 
+	// Clone the request to apply addtional headers without modifying the input.
+	request = request.Clone(requestCtx)
+	meek.addAdditionalHeaders(request)
+
 	// The setDialerRequestContext/CloseIdleConnections concurrency note in
 	// ObfuscatedRoundTrip applies to RoundTrip as well.
 

+ 3 - 0
psiphon/net.go

@@ -309,6 +309,9 @@ func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration
 
 	// Process the response
 	response, err := dnsConn.ReadMsg()
+	if err == nil && response.MsgHdr.Id != query.MsgHdr.Id {
+		err = dns.ErrId
+	}
 	if err != nil {
 		return nil, nil, errors.Trace(err)
 	}

+ 16 - 8
psiphon/notice.go

@@ -425,9 +425,10 @@ func NoticeCandidateServers(
 	singletonNoticeLogger.outputNotice(
 		"CandidateServers", noticeIsDiagnostic,
 		"region", region,
-		"initialLimitTunnelProtocols", constraints.initialLimitProtocols,
-		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitProtocolsCandidateCount,
-		"limitTunnelProtocols", constraints.limitProtocols,
+		"initialLimitTunnelProtocols", constraints.initialLimitTunnelProtocols,
+		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitTunnelProtocolsCandidateCount,
+		"limitTunnelProtocols", constraints.limitTunnelProtocols,
+		"limitTunnelDialPortNumbers", constraints.limitTunnelDialPortNumbers,
 		"replayCandidateCount", constraints.replayCandidateCount,
 		"initialCount", initialCount,
 		"count", count,
@@ -1034,14 +1035,21 @@ func GetNotice(notice []byte) (
 	var object noticeObject
 	err = json.Unmarshal(notice, &object)
 	if err != nil {
-		return "", nil, err
+		return "", nil, errors.Trace(err)
 	}
-	var objectPayload interface{}
-	err = json.Unmarshal(object.Data, &objectPayload)
+
+	var data interface{}
+	err = json.Unmarshal(object.Data, &data)
 	if err != nil {
-		return "", nil, err
+		return "", nil, errors.Trace(err)
+	}
+
+	dataValue, ok := data.(map[string]interface{})
+	if !ok {
+		return "", nil, errors.TraceNew("invalid data value")
 	}
-	return object.NoticeType, objectPayload.(map[string]interface{}), nil
+
+	return object.NoticeType, dataValue, nil
 }
 
 // NoticeReceiver consumes a notice input stream and invokes a callback function

+ 6 - 30
psiphon/server/listener.go

@@ -25,14 +25,15 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
 
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // of certain tactics parameters to accepted connections. Tactics filtering is
-// limited to GeoIP attributes as the client has not yet sent API paramaters.
+// limited to GeoIP attributes as the client has not yet sent API parameters.
+// GeoIP uses the immediate peer IP, and so TacticsListener is suitable only
+// for tactics that do not require the original client GeoIP when fronted.
 type TacticsListener struct {
 	net.Listener
 	support        *SupportServices
@@ -77,11 +78,14 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return nil, err
 	}
 
+	// Limitation: RemoteAddr is the immediate peer IP, which is not the original
+	// client IP in the case of fronting.
 	geoIPData := listener.geoIPLookup(
 		common.IPAddressFromAddr(conn.RemoteAddr()))
 
 	p, err := listener.support.ServerTacticsParametersCache.Get(geoIPData)
 	if err != nil {
+		conn.Close()
 		return nil, errors.Trace(err)
 	}
 
@@ -90,34 +94,6 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return conn, nil
 	}
 
-	// Disconnect immediately if the clients tactics restricts usage of the
-	// fronting provider ID. The probability may be used to influence usage of a
-	// given fronting provider; but when only that provider works for a given
-	// client, and the probability is less than 1.0, the client can retry until
-	// it gets a successful coin flip.
-	//
-	// Clients will also skip candidates with restricted fronting provider IDs.
-	// The client-side probability, RestrictFrontingProviderIDsClientProbability,
-	// is applied independently of the server-side coin flip here.
-	//
-	//
-	// At this stage, GeoIP tactics filters are active, but handshake API
-	// parameters are not.
-	//
-	// See the comment in server.LoadConfig regarding fronting provider ID
-	// limitations.
-
-	if protocol.TunnelProtocolUsesFrontedMeek(listener.tunnelProtocol) &&
-		common.Contains(
-			p.Strings(parameters.RestrictFrontingProviderIDs),
-			listener.support.Config.GetFrontingProviderID()) {
-		if p.WeightedCoinFlip(
-			parameters.RestrictFrontingProviderIDsServerProbability) {
-			conn.Close()
-			return nil, nil
-		}
-	}
-
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	//
 	// In the OSSH case, replay is always activated and it is seeded using the

+ 2 - 36
psiphon/server/listener_test.go

@@ -28,7 +28,6 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 )
@@ -37,8 +36,6 @@ func TestListener(t *testing.T) {
 
 	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
 
-	frontingProviderID := prng.HexString(8)
-
 	tacticsConfigJSONFormat := `
     {
       "RequestPublicKey" : "%s",
@@ -65,19 +62,6 @@ func TestListener(t *testing.T) {
               "FragmentorDownstreamMaxWriteBytes" : 1
             }
           }
-        },
-        {
-          "Filter" : {
-            "Regions": ["R3"],
-            "ISPs": ["I3"],
-            "Cities": ["C3"]
-          },
-          "Tactics" : {
-            "Parameters" : {
-              "RestrictFrontingProviderIDs" : ["%s"],
-              "RestrictFrontingProviderIDsServerProbability" : 1.0
-            }
-          }
         }
       ]
     }
@@ -92,7 +76,7 @@ func TestListener(t *testing.T) {
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
-		tunnelProtocol, frontingProviderID)
+		tunnelProtocol)
 
 	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
 
@@ -122,12 +106,6 @@ func TestListener(t *testing.T) {
 	listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData {
 		return GeoIPData{Country: "R1", ISP: "I1", City: "C2"}
 	}
-	listenerRestrictedFrontingProviderIDGeoIP := func(string) GeoIPData {
-		return GeoIPData{Country: "R3", ISP: "I3", City: "C3"}
-	}
-	listenerUnrestrictedFrontingProviderIDWrongRegion := func(string) GeoIPData {
-		return GeoIPData{Country: "R2", ISP: "I3", City: "C3"}
-	}
 
 	listenerTestCases := []struct {
 		description      string
@@ -159,18 +137,6 @@ func TestListener(t *testing.T) {
 			false,
 			true,
 		},
-		{
-			"restricted",
-			listenerRestrictedFrontingProviderIDGeoIP,
-			false,
-			false,
-		},
-		{
-			"unrestricted-region",
-			listenerUnrestrictedFrontingProviderIDWrongRegion,
-			false,
-			true,
-		},
 	}
 
 	for _, testCase := range listenerTestCases {
@@ -182,7 +148,7 @@ func TestListener(t *testing.T) {
 			}
 
 			support := &SupportServices{
-				Config:        &Config{frontingProviderID: frontingProviderID},
+				Config:        &Config{},
 				TacticsServer: tacticsServer,
 			}
 			support.ReplayCache = NewReplayCache(support)

+ 66 - 24
psiphon/server/meek.go

@@ -43,6 +43,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
@@ -344,8 +345,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		session,
 		underlyingConn,
 		endPoint,
-		clientIP,
+		endPointGeoIPData,
 		err := server.getSessionOrEndpoint(request, meekCookie)
+
 	if err != nil {
 		// Debug since session cookie errors commonly occur during
 		// normal operation.
@@ -359,9 +361,8 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// request handler.
 
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 		handled := server.support.TacticsServer.HandleEndPoint(
-			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
+			endPoint, common.GeoIPData(*endPointGeoIPData), responseWriter, request)
 		if !handled {
 			log.WithTraceFields(LogFields{"endPoint": endPoint}).Info("unhandled endpoint")
 			common.TerminateHTTPConnection(responseWriter, request)
@@ -587,7 +588,7 @@ func checkRangeHeader(request *http.Request) (int, bool) {
 // mode; or the endpoint is returned when the meek cookie indicates endpoint
 // mode.
 func (server *MeekServer) getSessionOrEndpoint(
-	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, string, error) {
+	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, *GeoIPData, error) {
 
 	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
 
@@ -601,7 +602,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 		// TODO: can multiple http client connections using same session cookie
 		// cause race conditions on session struct?
 		session.touch()
-		return existingSessionID, session, underlyingConn, "", "", nil
+		return existingSessionID, session, underlyingConn, "", nil, nil
 	}
 
 	// Determine the client remote address, which is used for geolocation
@@ -610,6 +611,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// headers such as X-Forwarded-For.
 
 	clientIP := strings.Split(request.RemoteAddr, ":")[0]
+	usedProxyForwardedForHeader := false
+	var geoIPData GeoIPData
 
 	if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 {
 		for _, header := range server.support.Config.MeekProxyForwardedForHeaders {
@@ -619,23 +622,29 @@ func (server *MeekServer) getSessionOrEndpoint(
 				// list of IPs (each proxy in a chain). The first IP should be
 				// the client IP.
 				proxyClientIP := strings.Split(value, ",")[0]
-				if net.ParseIP(proxyClientIP) != nil &&
-					server.support.GeoIPService.Lookup(
-						proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
-
-					clientIP = proxyClientIP
-					break
+				if net.ParseIP(proxyClientIP) != nil {
+					proxyClientGeoIPData := server.support.GeoIPService.Lookup(proxyClientIP)
+					if proxyClientGeoIPData.Country != GEOIP_UNKNOWN_VALUE {
+						usedProxyForwardedForHeader = true
+						clientIP = proxyClientIP
+						geoIPData = proxyClientGeoIPData
+						break
+					}
 				}
 			}
 		}
 	}
 
+	if !usedProxyForwardedForHeader {
+		geoIPData = server.support.GeoIPService.Lookup(clientIP)
+	}
+
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// cookie, extract the payload, and create a new session.
 
 	payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
 	if err != nil {
-		return "", nil, nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", nil, errors.Trace(err)
 	}
 
 	// Note: this meek server ignores legacy values PsiphonClientSessionId
@@ -644,7 +653,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	if err != nil {
-		return "", nil, nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", nil, errors.Trace(err)
 	}
 
 	tunnelProtocol := server.listenerTunnelProtocol
@@ -656,7 +665,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 			server.listenerTunnelProtocol,
 			server.support.Config.GetRunningProtocols()) {
 
-			return "", nil, nil, "", "", errors.Tracef(
+			return "", nil, nil, "", nil, errors.Tracef(
 				"invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol)
 		}
 
@@ -669,8 +678,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// rate limit is primarily intended to limit memory resource consumption and
 	// not the overhead incurred by cookie validation.
 
-	if server.rateLimit(clientIP, tunnelProtocol) {
-		return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
+	if server.rateLimit(clientIP, geoIPData, tunnelProtocol) {
+		return "", nil, nil, "", nil, errors.TraceNew("rate limit exceeded")
 	}
 
 	// Handle endpoints before enforcing CheckEstablishTunnels.
@@ -678,7 +687,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// handled by servers which would otherwise reject new tunnels.
 
 	if clientSessionData.EndPoint != "" {
-		return "", nil, nil, clientSessionData.EndPoint, clientIP, nil
+		return "", nil, nil, clientSessionData.EndPoint, &geoIPData, nil
 	}
 
 	// Don't create new sessions when not establishing. A subsequent SSH handshake
@@ -686,7 +695,42 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 	if server.support.TunnelServer != nil &&
 		!server.support.TunnelServer.CheckEstablishTunnels() {
-		return "", nil, nil, "", "", errors.TraceNew("not establishing tunnels")
+		return "", nil, nil, "", nil, errors.TraceNew("not establishing tunnels")
+	}
+
+	// Disconnect immediately if the tactics for the client restricts usage of
+	// the fronting provider ID. The probability may be used to influence
+	// usage of a given fronting provider; but when only that provider works
+	// for a given client, and the probability is less than 1.0, the client
+	// can retry until it gets a successful coin flip.
+	//
+	// Clients will also skip candidates with restricted fronting provider IDs.
+	// The client-side probability, RestrictFrontingProviderIDsClientProbability,
+	// is applied independently of the server-side coin flip here.
+	//
+	// At this stage, GeoIP tactics filters are active, but handshake API
+	// parameters are not.
+	//
+	// See the comment in server.LoadConfig regarding fronting provider ID
+	// limitations.
+
+	if protocol.TunnelProtocolUsesFrontedMeek(server.listenerTunnelProtocol) &&
+		server.support.ServerTacticsParametersCache != nil {
+
+		p, err := server.support.ServerTacticsParametersCache.Get(geoIPData)
+		if err != nil {
+			return "", nil, nil, "", nil, errors.Trace(err)
+		}
+
+		if !p.IsNil() &&
+			common.Contains(
+				p.Strings(parameters.RestrictFrontingProviderIDs),
+				server.support.Config.GetFrontingProviderID()) {
+			if p.WeightedCoinFlip(
+				parameters.RestrictFrontingProviderIDsServerProbability) {
+				return "", nil, nil, "", nil, errors.TraceNew("restricted fronting provider")
+			}
+		}
 	}
 
 	// Create a new session
@@ -736,7 +780,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 		sessionID, err = makeMeekSessionID()
 		if err != nil {
-			return "", nil, nil, "", "", errors.Trace(err)
+			return "", nil, nil, "", nil, errors.Trace(err)
 		}
 	}
 
@@ -748,10 +792,11 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// will close when session.delete calls Close() on the meekConn.
 	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 
-	return sessionID, session, underlyingConn, "", "", nil
+	return sessionID, session, underlyingConn, "", nil, nil
 }
 
-func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool {
+func (server *MeekServer) rateLimit(
+	clientIP string, geoIPData GeoIPData, tunnelProtocol string) bool {
 
 	historySize,
 		thresholdSeconds,
@@ -774,9 +819,6 @@ func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool
 
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 
-		// TODO: avoid redundant GeoIP lookups?
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
-
 		if len(regions) > 0 {
 			if !common.Contains(regions, geoIPData.Country) {
 				return false

+ 75 - 7
psiphon/server/meek_test.go

@@ -25,8 +25,10 @@ import (
 	crypto_rand "crypto/rand"
 	"encoding/base64"
 	"fmt"
+	"io/ioutil"
 	"math/rand"
 	"net"
+	"path/filepath"
 	"sync"
 	"sync/atomic"
 	"syscall"
@@ -38,6 +40,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"golang.org/x/crypto/nacl/box"
 )
 
@@ -245,6 +248,7 @@ func TestMeekResiliency(t *testing.T) {
 		},
 		TrafficRulesSet: &TrafficRulesSet{},
 	}
+	mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
 
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -401,19 +405,73 @@ func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) (
 }
 
 func TestMeekRateLimiter(t *testing.T) {
-	runTestMeekRateLimiter(t, true)
-	runTestMeekRateLimiter(t, false)
+	runTestMeekAccessControl(t, true, false)
+	runTestMeekAccessControl(t, false, false)
 }
 
-func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
+func TestMeekRestrictFrontingProviders(t *testing.T) {
+	runTestMeekAccessControl(t, false, true)
+	runTestMeekAccessControl(t, false, false)
+}
+
+func runTestMeekAccessControl(t *testing.T, rateLimit, restrictProvider bool) {
 
 	attempts := 10
 
 	allowedConnections := 5
+
 	if !rateLimit {
 		allowedConnections = 10
 	}
 
+	if restrictProvider {
+		allowedConnections = 0
+	}
+
+	// Configure tactics
+
+	frontingProviderID := prng.HexString(8)
+
+	tacticsConfigJSONFormat := `
+    {
+      "RequestPublicKey" : "%s",
+      "RequestPrivateKey" : "%s",
+      "RequestObfuscatedKey" : "%s",
+      "DefaultTactics" : {
+        "TTL" : "60s",
+        "Probability" : 1.0,
+        "Parameters" : {
+          "RestrictFrontingProviderIDs" : ["%s"],
+          "RestrictFrontingProviderIDsServerProbability" : 1.0
+        }
+      }
+    }
+    `
+
+	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
+		tactics.GenerateKeys()
+	if err != nil {
+		t.Fatalf("error generating tactics keys: %s", err)
+	}
+
+	restrictFrontingProviderID := ""
+
+	if restrictProvider {
+		restrictFrontingProviderID = frontingProviderID
+	}
+
+	tacticsConfigJSON := fmt.Sprintf(
+		tacticsConfigJSONFormat,
+		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+		restrictFrontingProviderID)
+
+	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
+
+	err = ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600)
+	if err != nil {
+		t.Fatalf("error paving tactics config file: %s", err)
+	}
+
 	// Run meek server
 
 	rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
@@ -424,11 +482,11 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 	meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
 	meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 
-	tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK
+	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
 
 	meekRateLimiterTunnelProtocols := []string{tunnelProtocol}
 	if !rateLimit {
-		meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS}
+		meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_FRONTED_MEEK}
 	}
 
 	mockSupport := &SupportServices{
@@ -436,6 +494,7 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 			MeekObfuscatedKey:              meekObfuscatedKey,
 			MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
 			TunnelProtocolPorts:            map[string]int{tunnelProtocol: 0},
+			frontingProviderID:             frontingProviderID,
 		},
 		TrafficRulesSet: &TrafficRulesSet{
 			MeekRateLimiterHistorySize:                   allowedConnections,
@@ -445,6 +504,15 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 			MeekRateLimiterReapHistoryFrequencySeconds:   1,
 		},
 	}
+	mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
+
+	tacticsServer, err := tactics.NewServer(nil, nil, nil, tacticsConfigFilename)
+	if err != nil {
+		t.Fatalf("tactics.NewServer failed: %s", err)
+	}
+
+	mockSupport.TacticsServer = tacticsServer
+	mockSupport.ServerTacticsParametersCache = NewServerTacticsParametersCache(mockSupport)
 
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -567,8 +635,8 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
 		totalFailures != attempts-totalConnections {
 
 		t.Fatalf(
-			"Unexpected results: %d connections, %d failures",
-			totalConnections, totalFailures)
+			"Unexpected results: %d connections, %d failures, %d allowed",
+			totalConnections, totalFailures, allowedConnections)
 	}
 
 	// Graceful shutdown

+ 1 - 1
psiphon/server/psinet/psinet.go

@@ -396,7 +396,7 @@ func calculateBucketCount(length int) int {
 func bucketizeServerList(servers []*DiscoveryServer, bucketCount int) [][]*DiscoveryServer {
 
 	// This code creates the same partitions as legacy servers:
-	// https://bitbucket.org/psiphon/psiphon-circumvention-system/src/03bc1a7e51e7c85a816e370bb3a6c755fd9c6fee/Automation/psi_ops_discovery.py
+	// https://github.com/Psiphon-Inc/psiphon-automation/blob/685f91a85bcdb33a75a200d936eadcb0686eadd7/Automation/psi_ops_discovery.py
 	//
 	// Both use the same algorithm from:
 	// http://stackoverflow.com/questions/2659900/python-slicing-a-list-into-n-nearly-equal-length-partitions

+ 76 - 6
psiphon/server/server_test.go

@@ -1338,6 +1338,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	expectServerBPFField := ServerBPFEnabled() && doServerTactics
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
 	expectBurstFields := runConfig.doBurstMonitor
+	expectTCPPortForwardDial := runConfig.doTunneledWebRequest
+	expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel
+	// Even with expectTrafficFailure, DNS port forwards will succeed
+	expectUDPDataTransfer := runConfig.doTunneledNTPRequest
 
 	select {
 	case logFields := <-serverTunnelLog:
@@ -1347,6 +1351,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectServerBPFField,
 			expectServerPacketManipulationField,
 			expectBurstFields,
+			expectTCPPortForwardDial,
+			expectTCPDataTransfer,
+			expectUDPDataTransfer,
 			logFields)
 		if err != nil {
 			t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1404,6 +1411,9 @@ func checkExpectedServerTunnelLogFields(
 	expectServerBPFField bool,
 	expectServerPacketManipulationField bool,
 	expectBurstFields bool,
+	expectTCPPortForwardDial bool,
+	expectTCPDataTransfer bool,
+	expectUDPDataTransfer bool,
 	fields map[string]interface{}) error {
 
 	// Limitations:
@@ -1649,6 +1659,66 @@ func checkExpectedServerTunnelLogFields(
 		return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
 	}
 
+	var checkTCPMetric func(float64) bool
+	if expectTCPPortForwardDial {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"peak_concurrent_dialing_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	if expectTCPDataTransfer {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_tcp",
+		"bytes_down_tcp",
+		"peak_concurrent_port_forward_count_tcp",
+		"total_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	var checkUDPMetric func(float64) bool
+	if expectUDPDataTransfer {
+		checkUDPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkUDPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_udp",
+		"bytes_down_udp",
+		"peak_concurrent_port_forward_count_udp",
+		"total_port_forward_count_udp",
+		"total_udpgw_channel_count",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkUDPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
 	return nil
 }
 
@@ -1998,12 +2068,12 @@ func paveTrafficRulesFile(
 
 	allowTCPPorts := TCPPorts
 	allowUDPPorts := UDPPorts
-	disallowTCPPorts := "0"
-	disallowUDPPorts := "0"
+	disallowTCPPorts := "1"
+	disallowUDPPorts := "1"
 
 	if deny {
-		allowTCPPorts = "0"
-		allowUDPPorts = "0"
+		allowTCPPorts = "1"
+		allowUDPPorts = "1"
 		disallowTCPPorts = TCPPorts
 		disallowUDPPorts = UDPPorts
 	}
@@ -2033,8 +2103,8 @@ func paveTrafficRulesFile(
                 "ReadUnthrottledBytes": %d,
                 "WriteUnthrottledBytes": %d
             },
-            "AllowTCPPorts" : [0],
-            "AllowUDPPorts" : [0],
+            "AllowTCPPorts" : [1],
+            "AllowUDPPorts" : [1],
             "MeekRateLimiterHistorySize" : 10,
             "MeekRateLimiterThresholdSeconds" : 1,
             "MeekRateLimiterGarbageCollectionTriggerCount" : 1,

+ 18 - 93
psiphon/server/trafficRules.go

@@ -236,21 +236,21 @@ type TrafficRules struct {
 
 	// AllowTCPPorts specifies a list of TCP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowTCPPorts []int
+	AllowTCPPorts *common.PortList
 
 	// AllowUDPPorts specifies a list of UDP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowUDPPorts []int
+	AllowUDPPorts *common.PortList
 
 	// DisallowTCPPorts specifies a list of TCP ports that are not permitted for
 	// port forwarding. DisallowTCPPorts takes priority over AllowTCPPorts and
 	// AllowSubnets.
-	DisallowTCPPorts []int
+	DisallowTCPPorts *common.PortList
 
 	// DisallowUDPPorts specifies a list of UDP ports that are not permitted for
 	// port forwarding. DisallowUDPPorts takes priority over AllowUDPPorts and
 	// AllowSubnets.
-	DisallowUDPPorts []int
+	DisallowUDPPorts *common.PortList
 
 	// AllowSubnets specifies a list of IP address subnets for which all TCP and
 	// UDP ports are allowed. This list is consulted if a port is disallowed by
@@ -261,11 +261,6 @@ type TrafficRules struct {
 	// client sends an IP address. Domain names are not resolved before checking
 	// AllowSubnets.
 	AllowSubnets []string
-
-	allowTCPPortsLookup    map[int]bool
-	allowUDPPortsLookup    map[int]bool
-	disallowTCPPortsLookup map[int]bool
-	disallowUDPPortsLookup map[int]bool
 }
 
 // RateLimits is a clone of common.RateLimits with pointers
@@ -434,33 +429,11 @@ func (set *TrafficRulesSet) initLookups() {
 
 	initTrafficRulesLookups := func(rules *TrafficRules) {
 
-		if len(rules.AllowTCPPorts) >= intLookupThreshold {
-			rules.allowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowTCPPorts {
-				rules.allowTCPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.AllowUDPPorts) >= intLookupThreshold {
-			rules.allowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowUDPPorts {
-				rules.allowUDPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.DisallowTCPPorts) >= intLookupThreshold {
-			rules.disallowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowTCPPorts {
-				rules.disallowTCPPortsLookup[port] = true
-			}
-		}
+		rules.AllowTCPPorts.OptimizeLookups()
+		rules.AllowUDPPorts.OptimizeLookups()
+		rules.DisallowTCPPorts.OptimizeLookups()
+		rules.DisallowUDPPorts.OptimizeLookups()
 
-		if len(rules.DisallowUDPPorts) >= intLookupThreshold {
-			rules.disallowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowUDPPorts {
-				rules.disallowUDPPortsLookup[port] = true
-			}
-		}
 	}
 
 	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
@@ -600,14 +573,6 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_MAX_UDP_PORT_FORWARD_COUNT)
 	}
 
-	if trafficRules.AllowTCPPorts == nil {
-		trafficRules.AllowTCPPorts = make([]int, 0)
-	}
-
-	if trafficRules.AllowUDPPorts == nil {
-		trafficRules.AllowUDPPorts = make([]int, 0)
-	}
-
 	if trafficRules.AllowSubnets == nil {
 		trafficRules.AllowSubnets = make([]string, 0)
 	}
@@ -800,22 +765,18 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 		if filteredRules.Rules.AllowTCPPorts != nil {
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
-			trafficRules.allowTCPPortsLookup = filteredRules.Rules.allowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowUDPPorts != nil {
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
-			trafficRules.allowUDPPortsLookup = filteredRules.Rules.allowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowTCPPorts != nil {
 			trafficRules.DisallowTCPPorts = filteredRules.Rules.DisallowTCPPorts
-			trafficRules.disallowTCPPortsLookup = filteredRules.Rules.disallowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowUDPPorts != nil {
 			trafficRules.DisallowUDPPorts = filteredRules.Rules.DisallowUDPPorts
-			trafficRules.disallowUDPPortsLookup = filteredRules.Rules.disallowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowSubnets != nil {
@@ -837,34 +798,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowTCPPorts) > 0 {
-		if rules.disallowTCPPortsLookup != nil {
-			if rules.disallowTCPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowTCPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowTCPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowTCPPorts) == 0 {
+	if rules.AllowTCPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowTCPPortsLookup != nil {
-		if rules.allowTCPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowTCPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowTCPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)
@@ -872,34 +815,16 @@ func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
 func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowUDPPorts) > 0 {
-		if rules.disallowUDPPortsLookup != nil {
-			if rules.disallowUDPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowUDPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowUDPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowUDPPorts) == 0 {
+	if rules.AllowUDPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowUDPPortsLookup != nil {
-		if rules.allowUDPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowUDPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowUDPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)

+ 40 - 18
psiphon/server/tunnelServer.go

@@ -1260,8 +1260,10 @@ type sshClient struct {
 	isFirstTunnelInSession               bool
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
-	udpChannel                           ssh.Channel
+	udpgwChannelHandler                  *udpgwPortForwardMultiplexer
+	totalUdpgwChannelCount               int
 	packetTunnelChannel                  ssh.Channel
+	totalPacketTunnelChannelCount        int
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	udpTrafficState                      trafficState
@@ -2495,11 +2497,11 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 
 	// Intercept TCP port forwards to a specified udpgw server and handle directly.
 	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
-	isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+	isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
 		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
 			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
 
-	if isUDPChannel {
+	if isUdpgwChannel {
 
 		// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
 		// own worker goroutine.
@@ -2507,7 +2509,7 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 		waitGroup.Add(1)
 		go func(channel ssh.NewChannel) {
 			defer waitGroup.Done()
-			sshClient.handleUDPChannel(channel)
+			sshClient.handleUdpgwChannel(channel)
 		}(newChannel)
 
 	} else {
@@ -2558,20 +2560,39 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
 		sshClient.packetTunnelChannel.Close()
 	}
 	sshClient.packetTunnelChannel = channel
+	sshClient.totalPacketTunnelChannelCount += 1
 	sshClient.Unlock()
 }
 
-// setUDPChannel sets the single UDP channel for this sshClient.
-// Each sshClient may have only one concurrent UDP channel. Each
-// UDP channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing UDP channel is closed.
-func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+// setUdpgwChannelHandler sets the single udpgw channel handler for this
+// sshClient. Each sshClient may have only one concurrent udpgw
+// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
+// the udpgw protocol. Any existing udpgw channel/handler is closed.
+func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool {
 	sshClient.Lock()
-	if sshClient.udpChannel != nil {
-		sshClient.udpChannel.Close()
+	if sshClient.udpgwChannelHandler != nil {
+		previousHandler := sshClient.udpgwChannelHandler
+		sshClient.udpgwChannelHandler = nil
+
+		// stop must be run without holding the sshClient mutex lock, as the
+		// udpgw goroutines may attempt to lock the same mutex. For example,
+		// udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward
+		// which calls sshClient.allocatePortForward.
+		sshClient.Unlock()
+		previousHandler.stop()
+		sshClient.Lock()
+
+		// In case some other channel has set the sshClient.udpgwChannelHandler
+		// in the meantime, fail. The caller should discard this channel/handler.
+		if sshClient.udpgwChannelHandler != nil {
+			sshClient.Unlock()
+			return false
+		}
 	}
-	sshClient.udpChannel = channel
+	sshClient.udpgwChannelHandler = udpgwChannelHandler
+	sshClient.totalUdpgwChannelCount += 1
 	sshClient.Unlock()
+	return true
 }
 
 var serverTunnelStatParams = append(
@@ -2616,6 +2637,8 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
+	logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount
+	logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount
 
 	logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
 	logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
@@ -2645,12 +2668,11 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 		}
 	}
 
-	sshClient.Unlock()
-
-	// Note: unlock before use is only safe as long as referenced sshClient data,
-	// such as slices in handshakeState, is read-only after initially set.
-
+	// Retain lock when invoking LogRawFieldsWithTimestamp to block any
+	// concurrent writes to variables referenced by logFields.
 	log.LogRawFieldsWithTimestamp(logFields)
+
+	sshClient.Unlock()
 }
 
 var blocklistHitsStatParams = []requestParamSpec{
@@ -2827,7 +2849,7 @@ func (sshClient *sshClient) enqueueDisallowedTrafficAlertRequest() {
 
 	sshClient.enqueueAlertRequest(
 		protocol.AlertRequest{
-			Reason:     protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC,
+			Reason:     reason,
 			ActionURLs: actionURLs,
 		})
 }

+ 105 - 54
psiphon/server/udp.go

@@ -25,7 +25,6 @@ import (
 	"fmt"
 	"io"
 	"net"
-	"runtime/debug"
 	"sync"
 	"sync/atomic"
 
@@ -35,7 +34,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 )
 
-// handleUDPChannel implements UDP port forwarding. A single UDP
+// handleUdpgwChannel implements UDP port forwarding. A single UDP
 // SSH channel follows the udpgw protocol, which multiplexes many
 // UDP port forwards.
 //
@@ -43,10 +42,10 @@ import (
 // Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com>
 // https://github.com/ambrop72/badvpn
 //
-func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 
 	// Accept this channel immediately. This channel will replace any
-	// previously existing UDP channel for this client.
+	// previously existing udpgw channel for this client.
 
 	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
@@ -58,33 +57,81 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	go ssh.DiscardRequests(requests)
 	defer sshChannel.Close()
 
-	sshClient.setUDPChannel(sshChannel)
-
-	multiplexer := &udpPortForwardMultiplexer{
+	multiplexer := &udpgwPortForwardMultiplexer{
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
-		portForwards:   make(map[uint16]*udpPortForward),
+		portForwards:   make(map[uint16]*udpgwPortForward),
 		portForwardLRU: common.NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
+		runWaitGroup:   new(sync.WaitGroup),
+	}
+
+	multiplexer.runWaitGroup.Add(1)
+
+	// setUdpgwChannelHandler will close any existing
+	// udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
+	// goroutines to first terminate and all UDP socket resources to be
+	// cleaned up.
+	//
+	// This synchronous shutdown also ensures that the
+	// concurrentPortForwardCount is reduced to 0 before installing the new
+	// udpgwPortForwardMultiplexer and its LRU object. If the older handler
+	// were to dangle with open port forwards, and concurrentPortForwardCount
+	// were to hit the max, the wrong LRU, the new one, would be used to
+	// close the LRU port forward.
+	//
+	// Call setUdpgwHandler only after runWaitGroup is initialized, to ensure
+	// runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
+	// channel) before initialized.
+
+	if !sshClient.setUdpgwChannelHandler(multiplexer) {
+		// setUdpgwChannelHandler returns false if some other SSH channel
+		// calls setUdpgwChannelHandler in the middle of this call. In that
+		// case, discard this channel: the client's latest udpgw channel is
+		// retained.
+		return
 	}
+
 	multiplexer.run()
+	multiplexer.runWaitGroup.Done()
 }
 
-type udpPortForwardMultiplexer struct {
+type udpgwPortForwardMultiplexer struct {
 	sshClient            *sshClient
 	sshChannelWriteMutex sync.Mutex
 	sshChannel           ssh.Channel
 	portForwardsMutex    sync.Mutex
-	portForwards         map[uint16]*udpPortForward
+	portForwards         map[uint16]*udpgwPortForward
 	portForwardLRU       *common.LRUConns
 	relayWaitGroup       *sync.WaitGroup
+	runWaitGroup         *sync.WaitGroup
+}
+
+func (mux *udpgwPortForwardMultiplexer) stop() {
+
+	// udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
+	//
+	// stop closes the udpgw SSH channel, which will cause the run goroutine
+	// to exit its message read loop and await closure of all relayDownstream
+	// goroutines. Closing all port forward UDP conns will cause all
+	// relayDownstream to exit.
+
+	_ = mux.sshChannel.Close()
+
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
+		_ = portForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+
+	mux.runWaitGroup.Wait()
 }
 
-func (mux *udpPortForwardMultiplexer) run() {
+func (mux *udpgwPortForwardMultiplexer) run() {
 
-	// In a loop, read udpgw messages from the client to this channel. Each message is
-	// a UDP packet to send upstream either via a new port forward, or on an existing
-	// port forward.
+	// In a loop, read udpgw messages from the client to this channel. Each
+	// message contains a UDP packet to send upstream either via a new port
+	// forward, or on an existing port forward.
 	//
 	// A goroutine is run to read downstream packets for each UDP port forward. All read
 	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
@@ -92,16 +139,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 	// When the client disconnects or the server shuts down, the channel will close and
 	// readUdpgwMessage will exit with EOF.
 
-	// Recover from and log any unexpected panics caused by udpgw input handling bugs.
-	// Note: this covers the run() goroutine only and not relayDownstream() goroutines.
-	defer func() {
-		if e := recover(); e != nil {
-			err := errors.Tracef(
-				"udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())
-			log.WithTraceFields(LogFields{"error": err}).Warning("run failed")
-		}
-	}()
-
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 	for {
 		// Note: message.packet points to the reusable memory in "buffer".
@@ -119,27 +156,37 @@ func (mux *udpPortForwardMultiplexer) run() {
 		portForward := mux.portForwards[message.connID]
 		mux.portForwardsMutex.Unlock()
 
-		if portForward != nil && message.discardExistingConn {
+		// In the udpgw protocol, an existing port forward is closed when
+		// either the discard flag is set or the remote address has changed.
+
+		if portForward != nil &&
+			(message.discardExistingConn ||
+				!bytes.Equal(portForward.remoteIP, message.remoteIP) ||
+				portForward.remotePort != message.remotePort) {
+
 			// The port forward's goroutine will complete cleanup, including
 			// tallying stats and calling sshClient.closedPortForward.
 			// portForward.conn.Close() will signal this shutdown.
-			// TODO: wait for goroutine to exit before proceeding?
 			portForward.conn.Close()
-			portForward = nil
-		}
-
-		if portForward != nil {
 
-			// Verify that portForward remote address matches latest message
+			// Synchronously await the termination of the relayDownstream
+			// goroutine. This ensures that the previous goroutine won't
+			// invoke removePortForward, with the connID that will be reused
+			// for the new port forward, after this point.
+			//
+			// Limitation: this synchronous shutdown cannot prevent a "wrong
+			// remote address" error on the badvpn udpgw client, which occurs
+			// when the client recycles a port forward (setting discard) but
+			// receives, from the server, a udpgw message containing the old
+			// remote address for the previous port forward with the same
+			// conn ID. That downstream message from the server may be in
+			// flight in the SSH channel when the client discard message arrives.
+			portForward.relayWaitGroup.Wait()
 
-			if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
-				portForward.remotePort != message.remotePort {
-
-				log.WithTrace().Warning("UDP port forward remote address mismatch")
-				continue
-			}
+			portForward = nil
+		}
 
-		} else {
+		if portForward == nil {
 
 			// Create a new port forward
 
@@ -237,17 +284,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 			}
 
-			portForward = &udpPortForward{
-				connID:       message.connID,
-				preambleSize: message.preambleSize,
-				remoteIP:     message.remoteIP,
-				remotePort:   message.remotePort,
-				dialIP:       dialIP,
-				conn:         conn,
-				lruEntry:     lruEntry,
-				bytesUp:      0,
-				bytesDown:    0,
-				mux:          mux,
+			portForward = &udpgwPortForward{
+				connID:         message.connID,
+				preambleSize:   message.preambleSize,
+				remoteIP:       message.remoteIP,
+				remotePort:     message.remotePort,
+				dialIP:         dialIP,
+				conn:           conn,
+				lruEntry:       lruEntry,
+				bytesUp:        0,
+				bytesDown:      0,
+				relayWaitGroup: new(sync.WaitGroup),
+				mux:            mux,
 			}
 
 			if message.forwardDNS {
@@ -258,6 +306,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 
+			portForward.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 		}
@@ -276,7 +325,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 
-	// Cleanup all UDP port forward workers when exiting
+	// Cleanup all udpgw port forward workers when exiting
 
 	mux.portForwardsMutex.Lock()
 	for _, portForward := range mux.portForwards {
@@ -288,13 +337,13 @@ func (mux *udpPortForwardMultiplexer) run() {
 	mux.relayWaitGroup.Wait()
 }
 
-func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) {
 	mux.portForwardsMutex.Lock()
 	delete(mux.portForwards, connID)
 	mux.portForwardsMutex.Unlock()
 }
 
-type udpPortForward struct {
+type udpgwPortForward struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
@@ -309,10 +358,12 @@ type udpPortForward struct {
 	dialIP            net.IP
 	conn              net.Conn
 	lruEntry          *common.LRUConnsEntry
-	mux               *udpPortForwardMultiplexer
+	relayWaitGroup    *sync.WaitGroup
+	mux               *udpgwPortForwardMultiplexer
 }
 
-func (portForward *udpPortForward) relayDownstream() {
+func (portForward *udpgwPortForward) relayDownstream() {
+	defer portForward.relayWaitGroup.Done()
 	defer portForward.mux.relayWaitGroup.Done()
 
 	// Downstream UDP packets are read into the reusable memory

+ 5 - 0
psiphon/upstreamproxy/go-ntlm/README.md

@@ -1,3 +1,8 @@
+
+This is a fork of the package that previously existed at https://github.com/ThomsonReutersEikon/go-ntlm/ntlm. It contains several bug fixes, tagged with "[Psiphon]".
+
+Original github.com/ThomsonReutersEikon/go-ntlm/ntlm README:
+
 # NTLM Implementation for Go
 
 This is a native implementation of NTLM for Go that was implemented using the Microsoft MS-NLMP documentation available at http://msdn.microsoft.com/en-us/library/cc236621.aspx.

+ 23 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/av_pairs.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"encoding/hex"
+	"errors"
 	"fmt"
 )
 
@@ -51,13 +52,16 @@ func (p *AvPairs) AddAvPair(avId AvPairType, bytes []byte) {
 	p.List = append(p.List, *a)
 }
 
-func ReadAvPairs(data []byte) *AvPairs {
+func ReadAvPairs(data []byte) (*AvPairs, error) {
 	pairs := new(AvPairs)
 
 	// Get the number of AvPairs and allocate enough AvPair structures to hold them
 	offset := 0
 	for i := 0; len(data) > 0 && i < 11; i++ {
-		pair := ReadAvPair(data, offset)
+		pair, err := ReadAvPair(data, offset)
+		if err != nil {
+			return nil, err
+		}
 		offset = offset + 4 + int(pair.AvLen)
 		pairs.List = append(pairs.List, *pair)
 		if pair.AvId == MsvAvEOL {
@@ -65,7 +69,7 @@ func ReadAvPairs(data []byte) *AvPairs {
 		}
 	}
 
-	return pairs
+	return pairs, nil
 }
 
 func (p *AvPairs) Bytes() (result []byte) {
@@ -131,12 +135,26 @@ type AvPair struct {
 	Value []byte
 }
 
-func ReadAvPair(data []byte, offset int) *AvPair {
+func ReadAvPair(data []byte, offset int) (*AvPair, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(data) < offset+4 {
+		return nil, errors.New("invalid AvPair")
+	}
+
 	pair := new(AvPair)
 	pair.AvId = AvPairType(binary.LittleEndian.Uint16(data[offset : offset+2]))
 	pair.AvLen = binary.LittleEndian.Uint16(data[offset+2 : offset+4])
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(data) < offset+4+int(pair.AvLen) {
+		return nil, errors.New("invalid AvPair")
+	}
+
 	pair.Value = data[offset+4 : offset+4+int(pair.AvLen)]
-	return pair
+	return pair, nil
 }
 
 func (a *AvPair) UnicodeStringValue() string {

+ 37 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/challenge_responses.go

@@ -21,6 +21,13 @@ func (n *NtlmV1Response) String() string {
 }
 
 func ReadNtlmV1Response(bytes []byte) (*NtlmV1Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid NTLM v1 response")
+	}
+
 	r := new(NtlmV1Response)
 	r.Response = bytes[0:24]
 	return r, nil
@@ -84,6 +91,13 @@ func (n *NtlmV2Response) String() string {
 }
 
 func ReadNtlmV2Response(bytes []byte) (*NtlmV2Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 45 {
+		return nil, errors.New("invalid NTLM v2 response")
+	}
+
 	r := new(NtlmV2Response)
 	r.Response = bytes[0:16]
 	r.NtlmV2ClientChallenge = new(NtlmV2ClientChallenge)
@@ -103,7 +117,11 @@ func ReadNtlmV2Response(bytes []byte) (*NtlmV2Response, error) {
 	c.ChallengeFromClient = bytes[32:40]
 	// Ignoring - 4 bytes reserved
 	// c.Reserved3
-	c.AvPairs = ReadAvPairs(bytes[44:])
+	var err error
+	c.AvPairs, err = ReadAvPairs(bytes[44:])
+	if err != nil {
+		return nil, err
+	}
 	return r, nil
 }
 
@@ -114,10 +132,17 @@ type LmV1Response struct {
 	Response []byte
 }
 
-func ReadLmV1Response(bytes []byte) *LmV1Response {
+func ReadLmV1Response(bytes []byte) (*LmV1Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid LM v1 response")
+	}
+
 	r := new(LmV1Response)
 	r.Response = bytes[0:24]
-	return r
+	return r, nil
 }
 
 func (l *LmV1Response) String() string {
@@ -136,11 +161,18 @@ type LmV2Response struct {
 	ChallengeFromClient []byte
 }
 
-func ReadLmV2Response(bytes []byte) *LmV2Response {
+func ReadLmV2Response(bytes []byte) (*LmV2Response, error) {
+
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < 24 {
+		return nil, errors.New("invalid LM v2 response")
+	}
+
 	r := new(LmV2Response)
 	r.Response = bytes[0:16]
 	r.ChallengeFromClient = bytes[16:24]
-	return r
+	return r, nil
 }
 
 func (l *LmV2Response) String() string {

+ 38 - 4
psiphon/upstreamproxy/go-ntlm/ntlm/message_authenticate.go

@@ -38,7 +38,7 @@ type AuthenticateMessage struct {
 	/// MS-NLMP 2.2.1.3 - In connectionless mode, a NEGOTIATE structure that contains a set of bit flags (section 2.2.2.5) and represents the
 	// conclusion of negotiation—the choices the client has made from the options the server offered in the CHALLENGE_MESSAGE.
 	// In connection-oriented mode, a NEGOTIATE structure that contains the set of bit flags (section 2.2.2.5) negotiated in
-	// the previous 
+	// the previous
 	NegotiateFlags uint32 // 4 bytes
 
 	// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION
@@ -56,6 +56,12 @@ type AuthenticateMessage struct {
 func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessage, error) {
 	am := new(AuthenticateMessage)
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < 12 {
+		return nil, errors.New("invalid authenticate message")
+	}
+
 	am.Signature = body[0:8]
 	if !bytes.Equal(am.Signature, []byte("NTLMSSP\x00")) {
 		return nil, errors.New("Invalid NTLM message signature")
@@ -74,9 +80,12 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 	}
 
 	if ntlmVersion == 2 {
-		am.LmV2Response = ReadLmV2Response(am.LmChallengeResponse.Payload)
+		am.LmV2Response, err = ReadLmV2Response(am.LmChallengeResponse.Payload)
 	} else {
-		am.LmV1Response = ReadLmV1Response(am.LmChallengeResponse.Payload)
+		am.LmV1Response, err = ReadLmV1Response(am.LmChallengeResponse.Payload)
+	}
+	if err != nil {
+		return nil, err
 	}
 
 	am.NtChallengeResponseFields, err = ReadBytePayload(20, body)
@@ -90,7 +99,6 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 	} else {
 		am.NtlmV1Response, err = ReadNtlmV1Response(am.NtChallengeResponseFields.Payload)
 	}
-
 	if err != nil {
 		return nil, err
 	}
@@ -124,11 +132,24 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 		}
 		offset = offset + 8
 
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(body) < offset+4 {
+			return nil, errors.New("invalid authenticate message")
+		}
+
 		am.NegotiateFlags = binary.LittleEndian.Uint32(body[offset : offset+4])
 		offset = offset + 4
 
 		// Version (8 bytes): A VERSION structure (section 2.2.2.10) that is present only when the NTLMSSP_NEGOTIATE_VERSION flag is set in the NegotiateFlags field. This structure is used for debugging purposes only. In normal protocol messages, it is ignored and does not affect the NTLM message processing.<9>
 		if NTLMSSP_NEGOTIATE_VERSION.IsSet(am.NegotiateFlags) {
+
+			// [Psiphon]
+			// Don't panic on malformed remote input.
+			if len(body) < offset+8 {
+				return nil, errors.New("invalid authenticate message")
+			}
+
 			am.Version, err = ReadVersionStruct(body[offset : offset+8])
 			if err != nil {
 				return nil, err
@@ -144,12 +165,25 @@ func ParseAuthenticateMessage(body []byte, ntlmVersion int) (*AuthenticateMessag
 		// there is a MIC and read it out.
 		var lowestOffset = am.getLowestPayloadOffset()
 		if lowestOffset > offset {
+
+			// [Psiphon]
+			// Don't panic on malformed remote input.
+			if len(body) < offset+16 {
+				return nil, errors.New("invalid authenticate message")
+			}
+
 			// MIC - 16 bytes
 			am.Mic = body[offset : offset+16]
 			offset = offset + 16
 		}
 	}
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < offset {
+		return nil, errors.New("invalid authenticate message")
+	}
+
 	am.Payload = body[offset:]
 
 	return am, nil

+ 23 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/message_challenge.go

@@ -56,6 +56,12 @@ type ChallengeMessage struct {
 func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 	challenge := new(ChallengeMessage)
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < 40 {
+		return nil, errors.New("invalid challenge message")
+	}
+
 	challenge.Signature = body[0:8]
 	if !bytes.Equal(challenge.Signature, []byte("NTLMSSP\x00")) {
 		return challenge, errors.New("Invalid NTLM message signature")
@@ -84,11 +90,21 @@ func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 		return nil, err
 	}
 
-	challenge.TargetInfo = ReadAvPairs(challenge.TargetInfoPayloadStruct.Payload)
+	challenge.TargetInfo, err = ReadAvPairs(challenge.TargetInfoPayloadStruct.Payload)
+	if err != nil {
+		return nil, err
+	}
 
 	offset := 48
 
 	if NTLMSSP_NEGOTIATE_VERSION.IsSet(challenge.NegotiateFlags) {
+
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(body) < offset+8 {
+			return nil, errors.New("invalid challenge message")
+		}
+
 		challenge.Version, err = ReadVersionStruct(body[offset : offset+8])
 		if err != nil {
 			return nil, err
@@ -96,6 +112,12 @@ func ParseChallengeMessage(body []byte) (*ChallengeMessage, error) {
 		offset = offset + 8
 	}
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(body) < offset {
+		return nil, errors.New("invalid challenge message")
+	}
+
 	challenge.Payload = body[offset:]
 
 	return challenge, nil

+ 6 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1.go

@@ -343,7 +343,12 @@ func (n *V1ClientSession) GenerateAuthenticateMessage() (am *AuthenticateMessage
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.UserName, _ = CreateStringPayload(n.user)
-	am.Workstation, _ = CreateStringPayload("SQUAREMILL")
+
+	// [Psiphon]
+	// Set a blank workstation value, which is less distinguishable than the previous hard-coded value.
+	// See also: https://github.com/Azure/go-ntlmssp/commit/5e29b886690f00c76b876ae9ab8e31e7c3509203.
+
+	am.Workstation, _ = CreateStringPayload("")
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.NegotiateFlags = n.NegotiateFlags
 	am.Version = &VersionStruct{ProductMajorVersion: uint8(5), ProductMinorVersion: uint8(1), ProductBuild: uint16(2600), NTLMRevisionCurrent: uint8(15)}

+ 5 - 5
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv1_test.go

@@ -42,14 +42,14 @@ func checkV1Value(t *testing.T, name string, value []byte, expected string, err
 // would authenticate. This was due to a bug in the MS-NLMP docs. This tests for that issue
 func TestNtlmV1ExtendedSessionSecurity(t *testing.T) {
 	// NTLMv1 with extended session security
-  challengeMessage := "TlRMTVNTUAACAAAAAAAAADgAAABVgphiRy3oSZvn1I4AAAAAAAAAAKIAogA4AAAABQEoCgAAAA8CAA4AUgBFAFUAVABFAFIAUwABABwAVQBLAEIAUAAtAEMAQgBUAFIATQBGAEUAMAA2AAQAFgBSAGUAdQB0AGUAcgBzAC4AbgBlAHQAAwA0AHUAawBiAHAALQBjAGIAdAByAG0AZgBlADAANgAuAFIAZQB1AHQAZQByAHMALgBuAGUAdAAFABYAUgBlAHUAdABlAHIAcwAuAG4AZQB0AAAAAAA="
-  authenticateMessage := "TlRMTVNTUAADAAAAGAAYAJgAAAAYABgAsAAAAAAAAABIAAAAOgA6AEgAAAAWABYAggAAABAAEADIAAAAVYKYYgUCzg4AAAAPMQAwADAAMAAwADEALgB3AGMAcABAAHQAaABvAG0AcwBvAG4AcgBlAHUAdABlAHIAcwAuAGMAbwBtAE4AWQBDAFMATQBTAEcAOQA5ADAAOQBRWAK3h/TIywAAAAAAAAAAAAAAAAAAAAA3tp89kZU1hs1XZp7KTyGm3XsFAT9stEDW9YXDaeYVBmBcBb//2FOu"
+	challengeMessage := "TlRMTVNTUAACAAAAAAAAADgAAABVgphiRy3oSZvn1I4AAAAAAAAAAKIAogA4AAAABQEoCgAAAA8CAA4AUgBFAFUAVABFAFIAUwABABwAVQBLAEIAUAAtAEMAQgBUAFIATQBGAEUAMAA2AAQAFgBSAGUAdQB0AGUAcgBzAC4AbgBlAHQAAwA0AHUAawBiAHAALQBjAGIAdAByAG0AZgBlADAANgAuAFIAZQB1AHQAZQByAHMALgBuAGUAdAAFABYAUgBlAHUAdABlAHIAcwAuAG4AZQB0AAAAAAA="
+	authenticateMessage := "TlRMTVNTUAADAAAAGAAYAJgAAAAYABgAsAAAAAAAAABIAAAAOgA6AEgAAAAWABYAggAAABAAEADIAAAAVYKYYgUCzg4AAAAPMQAwADAAMAAwADEALgB3AGMAcABAAHQAaABvAG0AcwBvAG4AcgBlAHUAdABlAHIAcwAuAGMAbwBtAE4AWQBDAFMATQBTAEcAOQA5ADAAOQBRWAK3h/TIywAAAAAAAAAAAAAAAAAAAAA3tp89kZU1hs1XZp7KTyGm3XsFAT9stEDW9YXDaeYVBmBcBb//2FOu"
 
 	challengeData, _ := base64.StdEncoding.DecodeString(challengeMessage)
 	c, _ := ParseChallengeMessage(challengeData)
 
-  authenticateData, _ := base64.StdEncoding.DecodeString(authenticateMessage)
-  msg, err := ParseAuthenticateMessage(authenticateData, 1)
+	authenticateData, _ := base64.StdEncoding.DecodeString(authenticateMessage)
+	msg, err := ParseAuthenticateMessage(authenticateData, 1)
 	if err != nil {
 		t.Errorf("Could not process authenticate message: %s", err)
 	}
@@ -62,7 +62,7 @@ func TestNtlmV1ExtendedSessionSecurity(t *testing.T) {
 	context.SetServerChallenge(c.ServerChallenge)
 	err = context.ProcessAuthenticateMessage(msg)
 	if err == nil {
-		t.Errorf("This message should have failed to authenticate, but it passed", err)
+		t.Errorf("This message should have failed to authenticate, but it passed")
 	}
 }
 

+ 6 - 2
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2.go

@@ -88,7 +88,6 @@ func (n *V2Session) Sign(message []byte) ([]byte, error) {
 	return nil, nil
 }
 
-//Mildly ghetto that we expose this
 func NtlmVCommonMac(message []byte, sequenceNumber int, sealingKey, signingKey []byte, NegotiateFlags uint32) []byte {
 	var handle *rc4P.Cipher
 	// TODO: Need to keep track of the sequence number for connection oriented NTLM
@@ -378,7 +377,12 @@ func (n *V2ClientSession) GenerateAuthenticateMessage() (am *AuthenticateMessage
 	am.NtChallengeResponseFields, _ = CreateBytePayload(n.ntChallengeResponse)
 	am.DomainName, _ = CreateStringPayload(n.userDomain)
 	am.UserName, _ = CreateStringPayload(n.user)
-	am.Workstation, _ = CreateStringPayload("SQUAREMILL")
+
+	// [Psiphon]
+	// Set a blank workstation value, which is less distinguishable than the previous hard-coded value.
+	// See also: https://github.com/Azure/go-ntlmssp/commit/5e29b886690f00c76b876ae9ab8e31e7c3509203.
+
+	am.Workstation, _ = CreateStringPayload("")
 	am.EncryptedRandomSessionKey, _ = CreateBytePayload(n.encryptedRandomSessionKey)
 	am.NegotiateFlags = n.NegotiateFlags
 	am.Mic = make([]byte, 16)

+ 1 - 1
psiphon/upstreamproxy/go-ntlm/ntlm/ntlmv2_test.go

@@ -172,7 +172,7 @@ func TestNTLMv2WithDomain(t *testing.T) {
 
 	err := server.ProcessAuthenticateMessage(a)
 	if err != nil {
-		t.Error("Could not process authenticate message: %s\n", err)
+		t.Errorf("Could not process authenticate message: %s\n", err)
 	}
 }
 

+ 14 - 0
psiphon/upstreamproxy/go-ntlm/ntlm/payload.go

@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"encoding/hex"
+	"errors"
 )
 
 const (
@@ -80,6 +81,12 @@ func ReadBytePayload(startByte int, bytes []byte) (*PayloadStruct, error) {
 func ReadPayloadStruct(startByte int, bytes []byte, PayloadType int) (*PayloadStruct, error) {
 	p := new(PayloadStruct)
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(bytes) < startByte+8 {
+		return nil, errors.New("invalid payload")
+	}
+
 	p.Type = PayloadType
 	p.Len = binary.LittleEndian.Uint16(bytes[startByte : startByte+2])
 	p.MaxLen = binary.LittleEndian.Uint16(bytes[startByte+2 : startByte+4])
@@ -87,6 +94,13 @@ func ReadPayloadStruct(startByte int, bytes []byte, PayloadType int) (*PayloadSt
 
 	if p.Len > 0 {
 		endOffset := p.Offset + uint32(p.Len)
+
+		// [Psiphon]
+		// Don't panic on malformed remote input.
+		if len(bytes) < int(endOffset) {
+			return nil, errors.New("invalid payload")
+		}
+
 		p.Payload = bytes[p.Offset:endOffset]
 	}
 

+ 7 - 0
psiphon/upstreamproxy/go-ntlm/ntlm/version.go

@@ -5,6 +5,7 @@ package ntlm
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 )
 
@@ -19,6 +20,12 @@ type VersionStruct struct {
 func ReadVersionStruct(structSource []byte) (*VersionStruct, error) {
 	versionStruct := new(VersionStruct)
 
+	// [Psiphon]
+	// Don't panic on malformed remote input.
+	if len(structSource) < 8 {
+		return nil, errors.New("invalid version struct")
+	}
+
 	versionStruct.ProductMajorVersion = uint8(structSource[0])
 	versionStruct.ProductMinorVersion = uint8(structSource[1])
 	versionStruct.ProductBuild = binary.LittleEndian.Uint16(structSource[2:4])

+ 7 - 1
vendor/github.com/Psiphon-Labs/tls-tris/key_agreement.go

@@ -72,7 +72,13 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
 		return nil, nil, err
 	}
 
-	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), pk.(*rsa.PublicKey), preMasterSecret)
+	// [Psiphon]
+	// Backport fix: https://github.com/golang/go/commit/58bc454a11d4b3dbc03f44dfcabb9068a9c076f4
+	rsaKey, ok := pk.(*rsa.PublicKey)
+	if !ok {
+		return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
+	}
+	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
 	if err != nil {
 		return nil, nil, err
 	}

+ 1 - 1
vendor/github.com/miekg/dns/Makefile.release

@@ -1,7 +1,7 @@
 # Makefile for releasing.
 #
 # The release is controlled from version.go. The version found there is
-# used to tag the git repo, we're not building any artifects so there is nothing
+# used to tag the git repo, we're not building any artifacts so there is nothing
 # to upload to github.
 #
 # * Up the version in version.go

+ 14 - 6
vendor/github.com/miekg/dns/README.md

@@ -26,7 +26,6 @@ avoiding breaking changes wherever reasonable. We support the last two versions
 A not-so-up-to-date-list-that-may-be-actually-current:
 
 * https://github.com/coredns/coredns
-* https://cloudflare.com
 * https://github.com/abh/geodns
 * https://github.com/baidu/bfe
 * http://www.statdns.com/
@@ -42,11 +41,9 @@ A not-so-up-to-date-list-that-may-be-actually-current:
 * https://github.com/StalkR/dns-reverse-proxy
 * https://github.com/tianon/rawdns
 * https://mesosphere.github.io/mesos-dns/
-* https://pulse.turbobytes.com/
 * https://github.com/fcambus/statzone
 * https://github.com/benschw/dns-clb-go
 * https://github.com/corny/dnscheck for <http://public-dns.info/>
-* https://namesmith.io
 * https://github.com/miekg/unbound
 * https://github.com/miekg/exdns
 * https://dnslookup.org
@@ -55,22 +52,30 @@ A not-so-up-to-date-list-that-may-be-actually-current:
 * https://github.com/mehrdadrad/mylg
 * https://github.com/bamarni/dockness
 * https://github.com/fffaraz/microdns
-* http://kelda.io
 * https://github.com/ipdcode/hades <https://jd.com>
 * https://github.com/StackExchange/dnscontrol/
 * https://www.dnsperf.com/
 * https://dnssectest.net/
-* https://dns.apebits.com
 * https://github.com/oif/apex
 * https://github.com/jedisct1/dnscrypt-proxy
 * https://github.com/jedisct1/rpdns
 * https://github.com/xor-gate/sshfp
 * https://github.com/rs/dnstrace
 * https://blitiri.com.ar/p/dnss ([github mirror](https://github.com/albertito/dnss))
-* https://github.com/semihalev/sdns
 * https://render.com
 * https://github.com/peterzen/goresolver
 * https://github.com/folbricht/routedns
+* https://domainr.com/
+* https://zonedb.org/
+* https://router7.org/
+* https://github.com/fortio/dnsping
+* https://github.com/Luzilla/dnsbl_exporter
+* https://github.com/bodgit/tsig
+* https://github.com/v2fly/v2ray-core (test only)
+* https://kuma.io/
+* https://www.misaka.io/services/dns
+* https://ping.sx/dig
+
 
 Send pull request if you want to be listed here.
 
@@ -167,6 +172,9 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
 * 7873 - Domain Name System (DNS) Cookies
 * 8080 - EdDSA for DNSSEC
 * 8499 - DNS Terminology
+* 8659 - DNS Certification Authority Authorization (CAA) Resource Record
+* 8914 - Extended DNS Errors
+* 8976 - Message Digest for DNS Zones (ZONEMD RR)
 
 ## Loosely Based Upon
 

+ 1 - 0
vendor/github.com/miekg/dns/acceptfunc.go

@@ -25,6 +25,7 @@ var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc
 // MsgAcceptAction represents the action to be taken.
 type MsgAcceptAction int
 
+// Allowed returned values from a MsgAcceptFunc.
 const (
 	MsgAccept               MsgAcceptAction = iota // Accept the message
 	MsgReject                                      // Reject the message with a RcodeFormatError

+ 110 - 45
vendor/github.com/miekg/dns/client.go

@@ -23,6 +23,7 @@ type Conn struct {
 	net.Conn                         // a net.Conn holding the connection
 	UDPSize        uint16            // minimum receive buffer for UDP messages
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
+	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
 	tsigRequestMAC string
 }
 
@@ -34,12 +35,13 @@ type Client struct {
 	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more
 	// Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout,
 	// WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and
-	// Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext)
+	// Client.Dialer) or context.Context.Deadline (see ExchangeContext)
 	Timeout        time.Duration
 	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
 	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
 	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
+	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
 	SingleInflight bool              // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
 	group          singleflight
 }
@@ -80,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration {
 
 // Dial connects to the address on the named network.
 func (c *Client) Dial(address string) (conn *Conn, err error) {
+	return c.DialContext(context.Background(), address)
+}
+
+// DialContext connects to the address on the named network, with a context.Context.
+// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released.
+func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
 	// create a new dialer with the appropriate timeout
 	var d net.Dialer
 	if c.Dialer == nil {
@@ -99,14 +107,22 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
 	if useTLS {
 		network = strings.TrimSuffix(network, "-tls")
 
+		// TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases.
+		/*
+			tlsDialer := tls.Dialer{
+				NetDialer: &d,
+				Config:    c.TLSConfig,
+			}
+			conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
+		*/
 		conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
 	} else {
-		conn.Conn, err = d.Dial(network, address)
+		conn.Conn, err = d.DialContext(ctx, network, address)
 	}
 	if err != nil {
 		return nil, err
 	}
-
+	conn.UDPSize = c.UDPSize
 	return conn, nil
 }
 
@@ -125,14 +141,46 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
 // To specify a local address or a timeout, the caller has to set the `Client.Dialer`
 // attribute appropriately
 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
+	co, err := c.Dial(address)
+
+	if err != nil {
+		return nil, 0, err
+	}
+	defer co.Close()
+	return c.ExchangeWithConn(m, co)
+}
+
+// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
+// that will be used instead of creating a new one.
+// Usage pattern with a *dns.Client:
+//
+//	c := new(dns.Client)
+//	// connection management logic goes here
+//
+//	conn := c.Dial(address)
+//	in, rtt, err := c.ExchangeWithConn(message, conn)
+//
+// This allows users of the library to implement their own connection management,
+// as opposed to Exchange, which will always use new connections and incur the added overhead
+// that entails when using "tcp" and especially "tcp-tls" clients.
+//
+// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
+// prevent one cancelation from canceling all outstanding requests.
+func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
+	return c.exchangeWithConnContext(context.Background(), m, conn)
+}
+
+func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
 	if !c.SingleInflight {
-		return c.exchange(m, address)
+		return c.exchangeContext(ctx, m, conn)
 	}
 
 	q := m.Question[0]
 	key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
 	r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
-		return c.exchange(m, address)
+		// When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries.
+		// Hence we ignore the context and use Background().
+		return c.exchangeContext(context.Background(), m, conn)
 	})
 	if r != nil && shared {
 		r = r.Copy()
@@ -141,16 +189,7 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
 	return r, rtt, err
 }
 
-func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
-	var co *Conn
-
-	co, err = c.Dial(a)
-
-	if err != nil {
-		return nil, 0, err
-	}
-	defer co.Close()
-
+func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
 	opt := m.IsEdns0()
 	// If EDNS0 is used use that for size.
 	if opt != nil && opt.UDPSize() >= MinMsgSize {
@@ -161,18 +200,41 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
 		co.UDPSize = c.UDPSize
 	}
 
-	co.TsigSecret = c.TsigSecret
-	t := time.Now()
 	// write with the appropriate write timeout
-	co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
+	t := time.Now()
+	writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
+	readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
+	if deadline, ok := ctx.Deadline(); ok {
+		if deadline.Before(writeDeadline) {
+			writeDeadline = deadline
+		}
+		if deadline.Before(readDeadline) {
+			readDeadline = deadline
+		}
+	}
+	co.SetWriteDeadline(writeDeadline)
+	co.SetReadDeadline(readDeadline)
+
+	co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
+
 	if err = co.WriteMsg(m); err != nil {
 		return nil, 0, err
 	}
 
-	co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout())))
-	r, err = co.ReadMsg()
-	if err == nil && r.Id != m.Id {
-		err = ErrId
+	if _, ok := co.Conn.(net.PacketConn); ok {
+		for {
+			r, err = co.ReadMsg()
+			// Ignore replies with mismatched IDs because they might be
+			// responses to earlier queries that timed out.
+			if err != nil || r.Id == m.Id {
+				break
+			}
+		}
+	} else {
+		r, err = co.ReadMsg()
+		if err == nil && r.Id != m.Id {
+			err = ErrId
+		}
 	}
 	rtt = time.Since(t)
 	return r, rtt, err
@@ -197,11 +259,15 @@ func (co *Conn) ReadMsg() (*Msg, error) {
 		return m, err
 	}
 	if t := m.IsTsig(); t != nil {
-		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
-			return m, ErrSecret
+		if co.TsigProvider != nil {
+			err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
+		} else {
+			if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
+				return m, ErrSecret
+			}
+			// Need to work on the original message p, as that was used to calculate the tsig.
+			err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		}
-		// Need to work on the original message p, as that was used to calculate the tsig.
-		err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 	}
 	return m, err
 }
@@ -279,10 +345,14 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
 	var out []byte
 	if t := m.IsTsig(); t != nil {
 		mac := ""
-		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
-			return ErrSecret
+		if co.TsigProvider != nil {
+			out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
+		} else {
+			if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
+				return ErrSecret
+			}
+			out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		}
-		out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
 		// Set for the next read, although only used in zone transfers
 		co.tsigRequestMAC = mac
 	} else {
@@ -305,11 +375,10 @@ func (co *Conn) Write(p []byte) (int, error) {
 		return co.Conn.Write(p)
 	}
 
-	l := make([]byte, 2)
-	binary.BigEndian.PutUint16(l, uint16(len(p)))
-
-	n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
-	return int(n), err
+	msg := make([]byte, 2+len(p))
+	binary.BigEndian.PutUint16(msg, uint16(len(p)))
+	copy(msg[2:], p)
+	return co.Conn.Write(msg)
 }
 
 // Return the appropriate timeout for a specific request
@@ -345,7 +414,7 @@ func Dial(network, address string) (conn *Conn, err error) {
 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
 	client := Client{Net: "udp"}
 	r, _, err = client.ExchangeContext(ctx, m, a)
-	// ignorint rtt to leave the original ExchangeContext API unchanged, but
+	// ignoring rtt to leave the original ExchangeContext API unchanged, but
 	// this function will go away
 	return r, err
 }
@@ -401,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
 // context, if present. If there is both a context deadline and a configured
 // timeout on the client, the earliest of the two takes effect.
 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
-	var timeout time.Duration
-	if deadline, ok := ctx.Deadline(); !ok {
-		timeout = 0
-	} else {
-		timeout = time.Until(deadline)
+	conn, err := c.DialContext(ctx, a)
+	if err != nil {
+		return nil, 0, err
 	}
-	// not passing the context to the underlying calls, as the API does not support
-	// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
-	// TODO(tmthrgd,miekg): this is a race condition.
-	c.Dialer = &net.Dialer{Timeout: timeout}
-	return c.Exchange(m, a)
+	defer conn.Close()
+
+	return c.exchangeWithConnContext(ctx, m, conn)
 }

+ 1 - 4
vendor/github.com/miekg/dns/defaults.go

@@ -349,10 +349,7 @@ func ReverseAddr(addr string) (arpa string, err error) {
 	// Add it, in reverse, to the buffer
 	for i := len(ip) - 1; i >= 0; i-- {
 		v := ip[i]
-		buf = append(buf, hexDigit[v&0xF])
-		buf = append(buf, '.')
-		buf = append(buf, hexDigit[v>>4])
-		buf = append(buf, '.')
+		buf = append(buf, hexDigit[v&0xF], '.', hexDigit[v>>4], '.')
 	}
 	// Append "ip6.arpa." and return (buf already has the final .)
 	buf = append(buf, "ip6.arpa."...)

+ 27 - 3
vendor/github.com/miekg/dns/dns.go

@@ -1,6 +1,9 @@
 package dns
 
-import "strconv"
+import (
+	"encoding/hex"
+	"strconv"
+)
 
 const (
 	year68     = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
@@ -111,7 +114,7 @@ func (h *RR_Header) parse(c *zlexer, origin string) *ParseError {
 
 // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
 func (rr *RFC3597) ToRFC3597(r RR) error {
-	buf := make([]byte, Len(r)*2)
+	buf := make([]byte, Len(r))
 	headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
 	if err != nil {
 		return err
@@ -126,9 +129,30 @@ func (rr *RFC3597) ToRFC3597(r RR) error {
 	}
 
 	_, err = rr.unpack(buf, headerEnd)
+	return err
+}
+
+// fromRFC3597 converts an unknown RR representation from RFC 3597 to the known RR type.
+func (rr *RFC3597) fromRFC3597(r RR) error {
+	hdr := r.Header()
+	*hdr = rr.Hdr
+
+	// Can't overflow uint16 as the length of Rdata is validated in (*RFC3597).parse.
+	// We can only get here when rr was constructed with that method.
+	hdr.Rdlength = uint16(hex.DecodedLen(len(rr.Rdata)))
+
+	if noRdata(*hdr) {
+		// Dynamic update.
+		return nil
+	}
+
+	// rr.pack requires an extra allocation and a copy so we just decode Rdata
+	// manually, it's simpler anyway.
+	msg, err := hex.DecodeString(rr.Rdata)
 	if err != nil {
 		return err
 	}
 
-	return nil
+	_, err = r.unpack(msg, 0)
+	return err
 }

+ 18 - 47
vendor/github.com/miekg/dns/dnssec.go

@@ -3,15 +3,14 @@ package dns
 import (
 	"bytes"
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
-	_ "crypto/md5"
 	"crypto/rand"
 	"crypto/rsa"
-	_ "crypto/sha1"
-	_ "crypto/sha256"
-	_ "crypto/sha512"
+	_ "crypto/sha1"   // need its init function
+	_ "crypto/sha256" // need its init function
+	_ "crypto/sha512" // need its init function
 	"encoding/asn1"
 	"encoding/binary"
 	"encoding/hex"
@@ -19,8 +18,6 @@ import (
 	"sort"
 	"strings"
 	"time"
-
-	"golang.org/x/crypto/ed25519"
 )
 
 // DNSSEC encryption algorithm codes.
@@ -318,6 +315,7 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
 		}
 
 		rr.Signature = toBase64(signature)
+		return nil
 	case RSAMD5, DSA, DSANSEC3SHA1:
 		// See RFC 6944.
 		return ErrAlg
@@ -332,9 +330,8 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
 		}
 
 		rr.Signature = toBase64(signature)
+		return nil
 	}
-
-	return nil
 }
 
 func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, error) {
@@ -346,7 +343,6 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte,
 	switch alg {
 	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		return signature, nil
-
 	case ECDSAP256SHA256, ECDSAP384SHA384:
 		ecdsaSignature := &struct {
 			R, S *big.Int
@@ -366,25 +362,18 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte,
 		signature := intToBytes(ecdsaSignature.R, intlen)
 		signature = append(signature, intToBytes(ecdsaSignature.S, intlen)...)
 		return signature, nil
-
-	// There is no defined interface for what a DSA backed crypto.Signer returns
-	case DSA, DSANSEC3SHA1:
-		// 	t := divRoundUp(divRoundUp(p.PublicKey.Y.BitLen(), 8)-64, 8)
-		// 	signature := []byte{byte(t)}
-		// 	signature = append(signature, intToBytes(r1, 20)...)
-		// 	signature = append(signature, intToBytes(s1, 20)...)
-		// 	rr.Signature = signature
-
 	case ED25519:
 		return signature, nil
+	default:
+		return nil, ErrAlg
 	}
-
-	return nil, ErrAlg
 }
 
 // Verify validates an RRSet with the signature and key. This is only the
 // cryptographic test, the signature validity period must be checked separately.
 // This function copies the rdata of some RRs (to lowercase domain names) for the validation to work.
+// It also checks that the Zone Key bit (RFC 4034 2.1.1) is set on the DNSKEY
+// and that the Protocol field is set to 3 (RFC 4034 2.1.2).
 func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	// First the easy checks
 	if !IsRRset(rrset) {
@@ -405,6 +394,12 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	if k.Protocol != 3 {
 		return ErrKey
 	}
+	// RFC 4034 2.1.1 If bit 7 has value 0, then the DNSKEY record holds some
+	// other type of DNS public key and MUST NOT be used to verify RRSIGs that
+	// cover RRsets.
+	if k.Flags&ZONE == 0 {
+		return ErrKey
+	}
 
 	// IsRRset checked that we have at least one RR and that the RRs in
 	// the set have consistent type, class, and name. Also check that type and
@@ -448,7 +443,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
 	}
 
 	switch rr.Algorithm {
-	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, RSAMD5:
+	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		// TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere??
 		pubkey := k.publicKeyRSA() // Get the key
 		if pubkey == nil {
@@ -512,7 +507,7 @@ func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
 	return ti <= utc && utc <= te
 }
 
-// Return the signatures base64 encodedig sigdata as a byte slice.
+// Return the signatures base64 encoding sigdata as a byte slice.
 func (rr *RRSIG) sigBuf() []byte {
 	sigbuf, err := fromBase64([]byte(rr.Signature))
 	if err != nil {
@@ -600,30 +595,6 @@ func (k *DNSKEY) publicKeyECDSA() *ecdsa.PublicKey {
 	return pubkey
 }
 
-func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
-	keybuf, err := fromBase64([]byte(k.PublicKey))
-	if err != nil {
-		return nil
-	}
-	if len(keybuf) < 22 {
-		return nil
-	}
-	t, keybuf := int(keybuf[0]), keybuf[1:]
-	size := 64 + t*8
-	q, keybuf := keybuf[:20], keybuf[20:]
-	if len(keybuf) != 3*size {
-		return nil
-	}
-	p, keybuf := keybuf[:size], keybuf[size:]
-	g, y := keybuf[:size], keybuf[size:]
-	pubkey := new(dsa.PublicKey)
-	pubkey.Parameters.Q = new(big.Int).SetBytes(q)
-	pubkey.Parameters.P = new(big.Int).SetBytes(p)
-	pubkey.Parameters.G = new(big.Int).SetBytes(g)
-	pubkey.Y = new(big.Int).SetBytes(y)
-	return pubkey
-}
-
 func (k *DNSKEY) publicKeyED25519() ed25519.PublicKey {
 	keybuf, err := fromBase64([]byte(k.PublicKey))
 	if err != nil {

+ 3 - 4
vendor/github.com/miekg/dns/dnssec_keygen.go

@@ -3,12 +3,11 @@ package dns
 import (
 	"crypto"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rsa"
 	"math/big"
-
-	"golang.org/x/crypto/ed25519"
 )
 
 // Generate generates a DNSKEY of the given bit size.
@@ -19,8 +18,6 @@ import (
 // bits should be set to the size of the algorithm.
 func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
 	switch k.Algorithm {
-	case RSAMD5, DSA, DSANSEC3SHA1:
-		return nil, ErrAlg
 	case RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
 		if bits < 512 || bits > 4096 {
 			return nil, ErrKeySize
@@ -41,6 +38,8 @@ func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
 		if bits != 256 {
 			return nil, ErrKeySize
 		}
+	default:
+		return nil, ErrAlg
 	}
 
 	switch k.Algorithm {

+ 4 - 17
vendor/github.com/miekg/dns/dnssec_keyscan.go

@@ -4,13 +4,12 @@ import (
 	"bufio"
 	"crypto"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/rsa"
 	"io"
 	"math/big"
 	"strconv"
 	"strings"
-
-	"golang.org/x/crypto/ed25519"
 )
 
 // NewPrivateKey returns a PrivateKey by parsing the string s.
@@ -43,15 +42,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 		return nil, ErrPrivKey
 	}
 	switch uint8(algo) {
-	case RSAMD5, DSA, DSANSEC3SHA1:
-		return nil, ErrAlg
-	case RSASHA1:
-		fallthrough
-	case RSASHA1NSEC3SHA1:
-		fallthrough
-	case RSASHA256:
-		fallthrough
-	case RSASHA512:
+	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
 		priv, err := readPrivateKeyRSA(m)
 		if err != nil {
 			return nil, err
@@ -62,11 +53,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 		}
 		priv.PublicKey = *pub
 		return priv, nil
-	case ECCGOST:
-		return nil, ErrPrivKey
-	case ECDSAP256SHA256:
-		fallthrough
-	case ECDSAP384SHA384:
+	case ECDSAP256SHA256, ECDSAP384SHA384:
 		priv, err := readPrivateKeyECDSA(m)
 		if err != nil {
 			return nil, err
@@ -80,7 +67,7 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er
 	case ED25519:
 		return readPrivateKeyED25519(m)
 	default:
-		return nil, ErrPrivKey
+		return nil, ErrAlg
 	}
 }
 

+ 3 - 20
vendor/github.com/miekg/dns/dnssec_privkey.go

@@ -2,13 +2,11 @@ package dns
 
 import (
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/rsa"
 	"math/big"
 	"strconv"
-
-	"golang.org/x/crypto/ed25519"
 )
 
 const format = "Private-key-format: v1.3\n"
@@ -17,8 +15,8 @@ var bigIntOne = big.NewInt(1)
 
 // PrivateKeyString converts a PrivateKey to a string. This string has the same
 // format as the private-key-file of BIND9 (Private-key-format: v1.3).
-// It needs some info from the key (the algorithm), so its a method of the DNSKEY
-// It supports rsa.PrivateKey, ecdsa.PrivateKey and dsa.PrivateKey
+// It needs some info from the key (the algorithm), so its a method of the DNSKEY.
+// It supports *rsa.PrivateKey, *ecdsa.PrivateKey and ed25519.PrivateKey.
 func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
 	algorithm := strconv.Itoa(int(r.Algorithm))
 	algorithm += " (" + AlgorithmToString[r.Algorithm] + ")"
@@ -67,21 +65,6 @@ func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
 			"Algorithm: " + algorithm + "\n" +
 			"PrivateKey: " + private + "\n"
 
-	case *dsa.PrivateKey:
-		T := divRoundUp(divRoundUp(p.PublicKey.Parameters.G.BitLen(), 8)-64, 8)
-		prime := toBase64(intToBytes(p.PublicKey.Parameters.P, 64+T*8))
-		subprime := toBase64(intToBytes(p.PublicKey.Parameters.Q, 20))
-		base := toBase64(intToBytes(p.PublicKey.Parameters.G, 64+T*8))
-		priv := toBase64(intToBytes(p.X, 20))
-		pub := toBase64(intToBytes(p.PublicKey.Y, 64+T*8))
-		return format +
-			"Algorithm: " + algorithm + "\n" +
-			"Prime(p): " + prime + "\n" +
-			"Subprime(q): " + subprime + "\n" +
-			"Base(g): " + base + "\n" +
-			"Private_value(x): " + priv + "\n" +
-			"Public_value(y): " + pub + "\n"
-
 	case ed25519.PrivateKey:
 		private := toBase64(p.Seed())
 		return format +

+ 29 - 5
vendor/github.com/miekg/dns/doc.go

@@ -159,7 +159,7 @@ shows the options you have and what functions to call.
 TRANSACTION SIGNATURE
 
 An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
-The supported algorithms include: HmacMD5, HmacSHA1, HmacSHA256 and HmacSHA512.
+The supported algorithms include: HmacSHA1, HmacSHA256 and HmacSHA512.
 
 Basic use pattern when querying with a TSIG name "axfr." (note that these key names
 must be fully qualified - as they are domain names) and the base64 secret
@@ -174,7 +174,7 @@ changes to the RRset after calling SetTsig() the signature will be incorrect.
 	c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	m := new(dns.Msg)
 	m.SetQuestion("miek.nl.", dns.TypeMX)
-	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+	m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 	...
 	// When sending the TSIG RR is calculated and filled in before sending
 
@@ -187,13 +187,37 @@ request an AXFR for miek.nl. with TSIG key named "axfr." and secret
 	m := new(dns.Msg)
 	t.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
 	m.SetAxfr("miek.nl.")
-	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+	m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 	c, err := t.In(m, "176.58.119.54:53")
 	for r := range c { ... }
 
 You can now read the records from the transfer as they come in. Each envelope
 is checked with TSIG. If something is not correct an error is returned.
 
+A custom TSIG implementation can be used. This requires additional code to
+perform any session establishment and signature generation/verification. The
+client must be configured with an implementation of the TsigProvider interface:
+
+	type Provider struct{}
+
+	func (*Provider) Generate(msg []byte, tsig *dns.TSIG) ([]byte, error) {
+		// Use tsig.Hdr.Name and tsig.Algorithm in your code to
+		// generate the MAC using msg as the payload.
+	}
+
+	func (*Provider) Verify(msg []byte, tsig *dns.TSIG) error {
+		// Use tsig.Hdr.Name and tsig.Algorithm in your code to verify
+		// that msg matches the value in tsig.MAC.
+	}
+
+	c := new(dns.Client)
+	c.TsigProvider = new(Provider)
+	m := new(dns.Msg)
+	m.SetQuestion("miek.nl.", dns.TypeMX)
+	m.SetTsig(keyname, dns.HmacSHA256, 300, time.Now().Unix())
+	...
+	// TSIG RR is calculated by calling your Generate method
+
 Basic use pattern validating and replying to a message that has TSIG set.
 
 	server := &dns.Server{Addr: ":53", Net: "udp"}
@@ -207,7 +231,7 @@ Basic use pattern validating and replying to a message that has TSIG set.
 		if r.IsTsig() != nil {
 			if w.TsigStatus() == nil {
 				// *Msg r has an TSIG record and it was validated
-				m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+				m.SetTsig("axfr.", dns.HmacSHA256, 300, time.Now().Unix())
 			} else {
 				// *Msg r has an TSIG records and it was not validated
 			}
@@ -260,7 +284,7 @@ From RFC 2931:
     on requests and responses, and protection of the overall integrity of a response.
 
 It works like TSIG, except that SIG(0) uses public key cryptography, instead of
-the shared secret approach in TSIG. Supported algorithms: DSA, ECDSAP256SHA256,
+the shared secret approach in TSIG. Supported algorithms: ECDSAP256SHA256,
 ECDSAP384SHA384, RSASHA1, RSASHA256 and RSASHA512.
 
 Signing subsequent messages in multi-message sessions is not implemented.

+ 12 - 4
vendor/github.com/miekg/dns/duplicate_generate.go

@@ -30,6 +30,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 		return nil, false
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 	}
@@ -83,10 +86,7 @@ func main() {
 	for _, name := range namedTypes {
 
 		o := scope.Lookup(name)
-		st, isEmbedded := getTypeStruct(o.Type(), scope)
-		if isEmbedded {
-			continue
-		}
+		st, _ := getTypeStruct(o.Type(), scope)
 		fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name)
 		fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name)
 		fmt.Fprint(b, "if !ok { return false }\n")
@@ -121,6 +121,14 @@ func main() {
 					continue
 				}
 
+				if st.Tag(i) == `dns:"pairs"` {
+					o2(`if !areSVCBPairArraysEqual(r1.%s, r2.%s) {
+						return false
+					}`)
+
+					continue
+				}
+
 				o3(`for i := 0; i < len(r1.%s); i++ {
 					if r1.%s[i] != r2.%s[i] {
 						return false

+ 151 - 5
vendor/github.com/miekg/dns/edns.go

@@ -22,11 +22,47 @@ const (
 	EDNS0COOKIE       = 0xa     // EDNS0 Cookie
 	EDNS0TCPKEEPALIVE = 0xb     // EDNS0 tcp keep alive (See RFC 7828)
 	EDNS0PADDING      = 0xc     // EDNS0 padding (See RFC 7830)
+	EDNS0EDE          = 0xf     // EDNS0 extended DNS errors (See RFC 8914)
 	EDNS0LOCALSTART   = 0xFDE9  // Beginning of range reserved for local/experimental use (See RFC 6891)
 	EDNS0LOCALEND     = 0xFFFE  // End of range reserved for local/experimental use (See RFC 6891)
 	_DO               = 1 << 15 // DNSSEC OK
 )
 
+// makeDataOpt is used to unpack the EDNS0 option(s) from a message.
+func makeDataOpt(code uint16) EDNS0 {
+	// All the EDNS0.* constants above need to be in this switch.
+	switch code {
+	case EDNS0LLQ:
+		return new(EDNS0_LLQ)
+	case EDNS0UL:
+		return new(EDNS0_UL)
+	case EDNS0NSID:
+		return new(EDNS0_NSID)
+	case EDNS0DAU:
+		return new(EDNS0_DAU)
+	case EDNS0DHU:
+		return new(EDNS0_DHU)
+	case EDNS0N3U:
+		return new(EDNS0_N3U)
+	case EDNS0SUBNET:
+		return new(EDNS0_SUBNET)
+	case EDNS0EXPIRE:
+		return new(EDNS0_EXPIRE)
+	case EDNS0COOKIE:
+		return new(EDNS0_COOKIE)
+	case EDNS0TCPKEEPALIVE:
+		return new(EDNS0_TCP_KEEPALIVE)
+	case EDNS0PADDING:
+		return new(EDNS0_PADDING)
+	case EDNS0EDE:
+		return new(EDNS0_EDE)
+	default:
+		e := new(EDNS0_LOCAL)
+		e.Code = code
+		return e
+	}
+}
+
 // OPT is the EDNS0 RR appended to messages to convey extra (meta) information.
 // See RFC 6891.
 type OPT struct {
@@ -73,6 +109,8 @@ func (rr *OPT) String() string {
 			s += "\n; LOCAL OPT: " + o.String()
 		case *EDNS0_PADDING:
 			s += "\n; PADDING: " + o.String()
+		case *EDNS0_EDE:
+			s += "\n; EDE: " + o.String()
 		}
 	}
 	return s
@@ -88,11 +126,11 @@ func (rr *OPT) len(off int, compression map[string]struct{}) int {
 	return l
 }
 
-func (rr *OPT) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on OPT")
+func (*OPT) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "OPT records do not have a presentation format"}
 }
 
-func (r1 *OPT) isDuplicate(r2 RR) bool { return false }
+func (rr *OPT) isDuplicate(r2 RR) bool { return false }
 
 // return the old value -> delete SetVersion?
 
@@ -148,6 +186,16 @@ func (rr *OPT) SetDo(do ...bool) {
 	}
 }
 
+// Z returns the Z part of the OPT RR as a uint16 with only the 15 least significant bits used.
+func (rr *OPT) Z() uint16 {
+	return uint16(rr.Hdr.Ttl & 0x7FFF)
+}
+
+// SetZ sets the Z part of the OPT RR, note only the 15 least significant bits of z are used.
+func (rr *OPT) SetZ(z uint16) {
+	rr.Hdr.Ttl = rr.Hdr.Ttl&^0x7FFF | uint32(z&0x7FFF)
+}
+
 // EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to it.
 type EDNS0 interface {
 	// Option returns the option code for the option.
@@ -452,7 +500,7 @@ func (e *EDNS0_LLQ) copy() EDNS0 {
 	return &EDNS0_LLQ{e.Code, e.Version, e.Opcode, e.Error, e.Id, e.LeaseLife}
 }
 
-// EDNS0_DUA implements the EDNS0 "DNSSEC Algorithm Understood" option. See RFC 6975.
+// EDNS0_DAU implements the EDNS0 "DNSSEC Algorithm Understood" option. See RFC 6975.
 type EDNS0_DAU struct {
 	Code    uint16 // Always EDNS0DAU
 	AlgCode []uint8
@@ -525,7 +573,7 @@ func (e *EDNS0_N3U) String() string {
 }
 func (e *EDNS0_N3U) copy() EDNS0 { return &EDNS0_N3U{e.Code, e.AlgCode} }
 
-// EDNS0_EXPIRE implementes the EDNS0 option as described in RFC 7314.
+// EDNS0_EXPIRE implements the EDNS0 option as described in RFC 7314.
 type EDNS0_EXPIRE struct {
 	Code   uint16 // Always EDNS0EXPIRE
 	Expire uint32
@@ -673,3 +721,101 @@ func (e *EDNS0_PADDING) copy() EDNS0 {
 	copy(b, e.Padding)
 	return &EDNS0_PADDING{b}
 }
+
+// Extended DNS Error Codes (RFC 8914).
+const (
+	ExtendedErrorCodeOther uint16 = iota
+	ExtendedErrorCodeUnsupportedDNSKEYAlgorithm
+	ExtendedErrorCodeUnsupportedDSDigestType
+	ExtendedErrorCodeStaleAnswer
+	ExtendedErrorCodeForgedAnswer
+	ExtendedErrorCodeDNSSECIndeterminate
+	ExtendedErrorCodeDNSBogus
+	ExtendedErrorCodeSignatureExpired
+	ExtendedErrorCodeSignatureNotYetValid
+	ExtendedErrorCodeDNSKEYMissing
+	ExtendedErrorCodeRRSIGsMissing
+	ExtendedErrorCodeNoZoneKeyBitSet
+	ExtendedErrorCodeNSECMissing
+	ExtendedErrorCodeCachedError
+	ExtendedErrorCodeNotReady
+	ExtendedErrorCodeBlocked
+	ExtendedErrorCodeCensored
+	ExtendedErrorCodeFiltered
+	ExtendedErrorCodeProhibited
+	ExtendedErrorCodeStaleNXDOMAINAnswer
+	ExtendedErrorCodeNotAuthoritative
+	ExtendedErrorCodeNotSupported
+	ExtendedErrorCodeNoReachableAuthority
+	ExtendedErrorCodeNetworkError
+	ExtendedErrorCodeInvalidData
+)
+
+// ExtendedErrorCodeToString maps extended error info codes to a human readable
+// description.
+var ExtendedErrorCodeToString = map[uint16]string{
+	ExtendedErrorCodeOther:                      "Other",
+	ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: "Unsupported DNSKEY Algorithm",
+	ExtendedErrorCodeUnsupportedDSDigestType:    "Unsupported DS Digest Type",
+	ExtendedErrorCodeStaleAnswer:                "Stale Answer",
+	ExtendedErrorCodeForgedAnswer:               "Forged Answer",
+	ExtendedErrorCodeDNSSECIndeterminate:        "DNSSEC Indeterminate",
+	ExtendedErrorCodeDNSBogus:                   "DNSSEC Bogus",
+	ExtendedErrorCodeSignatureExpired:           "Signature Expired",
+	ExtendedErrorCodeSignatureNotYetValid:       "Signature Not Yet Valid",
+	ExtendedErrorCodeDNSKEYMissing:              "DNSKEY Missing",
+	ExtendedErrorCodeRRSIGsMissing:              "RRSIGs Missing",
+	ExtendedErrorCodeNoZoneKeyBitSet:            "No Zone Key Bit Set",
+	ExtendedErrorCodeNSECMissing:                "NSEC Missing",
+	ExtendedErrorCodeCachedError:                "Cached Error",
+	ExtendedErrorCodeNotReady:                   "Not Ready",
+	ExtendedErrorCodeBlocked:                    "Blocked",
+	ExtendedErrorCodeCensored:                   "Censored",
+	ExtendedErrorCodeFiltered:                   "Filtered",
+	ExtendedErrorCodeProhibited:                 "Prohibited",
+	ExtendedErrorCodeStaleNXDOMAINAnswer:        "Stale NXDOMAIN Answer",
+	ExtendedErrorCodeNotAuthoritative:           "Not Authoritative",
+	ExtendedErrorCodeNotSupported:               "Not Supported",
+	ExtendedErrorCodeNoReachableAuthority:       "No Reachable Authority",
+	ExtendedErrorCodeNetworkError:               "Network Error",
+	ExtendedErrorCodeInvalidData:                "Invalid Data",
+}
+
+// StringToExtendedErrorCode is a map from human readable descriptions to
+// extended error info codes.
+var StringToExtendedErrorCode = reverseInt16(ExtendedErrorCodeToString)
+
+// EDNS0_EDE option is used to return additional information about the cause of
+// DNS errors.
+type EDNS0_EDE struct {
+	InfoCode  uint16
+	ExtraText string
+}
+
+// Option implements the EDNS0 interface.
+func (e *EDNS0_EDE) Option() uint16 { return EDNS0EDE }
+func (e *EDNS0_EDE) copy() EDNS0    { return &EDNS0_EDE{e.InfoCode, e.ExtraText} }
+
+func (e *EDNS0_EDE) String() string {
+	info := strconv.FormatUint(uint64(e.InfoCode), 10)
+	if s, ok := ExtendedErrorCodeToString[e.InfoCode]; ok {
+		info += fmt.Sprintf(" (%s)", s)
+	}
+	return fmt.Sprintf("%s: (%s)", info, e.ExtraText)
+}
+
+func (e *EDNS0_EDE) pack() ([]byte, error) {
+	b := make([]byte, 2+len(e.ExtraText))
+	binary.BigEndian.PutUint16(b[0:], e.InfoCode)
+	copy(b[2:], []byte(e.ExtraText))
+	return b, nil
+}
+
+func (e *EDNS0_EDE) unpack(b []byte) error {
+	if len(b) < 2 {
+		return ErrBuf
+	}
+	e.InfoCode = binary.BigEndian.Uint16(b[0:])
+	e.ExtraText = string(b[2:])
+	return nil
+}

+ 12 - 12
vendor/github.com/miekg/dns/generate.go

@@ -20,13 +20,13 @@ import (
 // of $ after that are interpreted.
 func (zp *ZoneParser) generate(l lex) (RR, bool) {
 	token := l.token
-	step := 1
+	step := int64(1)
 	if i := strings.IndexByte(token, '/'); i >= 0 {
 		if i+1 == len(token) {
 			return zp.setParseError("bad step in $GENERATE range", l)
 		}
 
-		s, err := strconv.Atoi(token[i+1:])
+		s, err := strconv.ParseInt(token[i+1:], 10, 64)
 		if err != nil || s <= 0 {
 			return zp.setParseError("bad step in $GENERATE range", l)
 		}
@@ -40,12 +40,12 @@ func (zp *ZoneParser) generate(l lex) (RR, bool) {
 		return zp.setParseError("bad start-stop in $GENERATE range", l)
 	}
 
-	start, err := strconv.Atoi(sx[0])
+	start, err := strconv.ParseInt(sx[0], 10, 64)
 	if err != nil {
 		return zp.setParseError("bad start in $GENERATE range", l)
 	}
 
-	end, err := strconv.Atoi(sx[1])
+	end, err := strconv.ParseInt(sx[1], 10, 64)
 	if err != nil {
 		return zp.setParseError("bad stop in $GENERATE range", l)
 	}
@@ -94,10 +94,10 @@ type generateReader struct {
 	s  string
 	si int
 
-	cur   int
-	start int
-	end   int
-	step  int
+	cur   int64
+	start int64
+	end   int64
+	step  int64
 
 	mod bytes.Buffer
 
@@ -173,7 +173,7 @@ func (r *generateReader) ReadByte() (byte, error) {
 			return '$', nil
 		}
 
-		var offset int
+		var offset int64
 
 		// Search for { and }
 		if r.s[si+1] == '{' {
@@ -208,7 +208,7 @@ func (r *generateReader) ReadByte() (byte, error) {
 }
 
 // Convert a $GENERATE modifier 0,0,d to something Printf can deal with.
-func modToPrintf(s string) (string, int, string) {
+func modToPrintf(s string) (string, int64, string) {
 	// Modifier is { offset [ ,width [ ,base ] ] } - provide default
 	// values for optional width and type, if necessary.
 	var offStr, widthStr, base string
@@ -229,12 +229,12 @@ func modToPrintf(s string) (string, int, string) {
 		return "", 0, "bad base in $GENERATE"
 	}
 
-	offset, err := strconv.Atoi(offStr)
+	offset, err := strconv.ParseInt(offStr, 10, 64)
 	if err != nil {
 		return "", 0, "bad offset in $GENERATE"
 	}
 
-	width, err := strconv.Atoi(widthStr)
+	width, err := strconv.ParseInt(widthStr, 10, 64)
 	if err != nil || width < 0 || width > 255 {
 		return "", 0, "bad width in $GENERATE"
 	}

+ 4 - 6
vendor/github.com/miekg/dns/go.mod

@@ -1,11 +1,9 @@
 module github.com/miekg/dns
 
-go 1.12
+go 1.14
 
 require (
-	golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
-	golang.org/x/net v0.0.0-20190923162816-aa69164e4478
-	golang.org/x/sync v0.0.0-20190423024810-112230192c58
-	golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe
-	golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect
+	golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
+	golang.org/x/sys v0.0.0-20210303074136-134d130e1a04
 )

+ 9 - 38
vendor/github.com/miekg/dns/go.sum

@@ -1,39 +1,10 @@
-golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 h1:Vk3wNqEZwyGyei9yq5ekj7frek2u7HUfffJ1/opblzc=
-golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
-golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392 h1:ACG4HJsFiNMf47Y4PeRoebLNy/2lXT9EtprMuTFWt1M=
-golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/net v0.0.0-20180926154720-4dfa2610cdf3 h1:dgd4x4kJt7G4k4m93AYLzM8Ni6h2qLTfh9n9vXJT3/0=
-golang.org/x/net v0.0.0-20180926154720-4dfa2610cdf3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM=
-golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g=
-golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 h1:O33LKL7WyJgjN9CvxfTIomjIClbd/Kq86/iipowHQU0=
-golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190904154756-749cb33beabd h1:DBH9mDw0zluJT/R+nGuV3jWFWLFaHyYZWD4tOT+cjn0=
-golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M=
-golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
+golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
-golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

+ 1 - 1
vendor/github.com/miekg/dns/labels.go

@@ -10,7 +10,7 @@ package dns
 // escaped dots (\.) for instance.
 // s must be a syntactically valid domain name, see IsDomainName.
 func SplitDomainName(s string) (labels []string) {
-	if len(s) == 0 {
+	if s == "" {
 		return nil
 	}
 	fqdnEnd := 0 // offset of the final '.' or the length of the name

+ 0 - 0
vendor/github.com/miekg/dns/listen_go_not111.go → vendor/github.com/miekg/dns/listen_no_reuseport.go


+ 0 - 0
vendor/github.com/miekg/dns/listen_go111.go → vendor/github.com/miekg/dns/listen_reuseport.go


+ 15 - 13
vendor/github.com/miekg/dns/msg.go

@@ -398,17 +398,12 @@ Loop:
 				return "", lenmsg, ErrLongDomain
 			}
 			for _, b := range msg[off : off+c] {
-				switch b {
-				case '.', '(', ')', ';', ' ', '@':
-					fallthrough
-				case '"', '\\':
+				if isDomainNameLabelSpecial(b) {
 					s = append(s, '\\', b)
-				default:
-					if b < ' ' || b > '~' { // unprintable, use \DDD
-						s = append(s, escapeByte(b)...)
-					} else {
-						s = append(s, b)
-					}
+				} else if b < ' ' || b > '~' {
+					s = append(s, escapeByte(b)...)
+				} else {
+					s = append(s, b)
 				}
 			}
 			s = append(s, '.')
@@ -629,11 +624,18 @@ func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err
 		rr = &RFC3597{Hdr: h}
 	}
 
-	if noRdata(h) {
-		return rr, off, nil
+	if off < 0 || off > len(msg) {
+		return &h, off, &Error{err: "bad off"}
 	}
 
 	end := off + int(h.Rdlength)
+	if end < off || end > len(msg) {
+		return &h, end, &Error{err: "bad rdlength"}
+	}
+
+	if noRdata(h) {
+		return rr, off, nil
+	}
 
 	off, err = rr.unpack(msg, off)
 	if err != nil {
@@ -740,7 +742,7 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compression
 	}
 
 	// Set extended rcode unconditionally if we have an opt, this will allow
-	// reseting the extended rcode bits if they need to.
+	// resetting the extended rcode bits if they need to.
 	if opt := dns.IsEdns0(); opt != nil {
 		opt.SetExtendedRcode(uint16(dns.Rcode))
 	} else if dns.Rcode > 0xF {

+ 7 - 0
vendor/github.com/miekg/dns/msg_generate.go

@@ -35,6 +35,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 		return nil, false
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 	}
@@ -110,6 +113,8 @@ return off, err
 					o("off, err = packDataOpt(rr.%s, msg, off)\n")
 				case `dns:"nsec"`:
 					o("off, err = packDataNsec(rr.%s, msg, off)\n")
+				case `dns:"pairs"`:
+					o("off, err = packDataSVCB(rr.%s, msg, off)\n")
 				case `dns:"domain-name"`:
 					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
 				case `dns:"apl"`:
@@ -236,6 +241,8 @@ return off, err
 					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
 				case `dns:"nsec"`:
 					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
+				case `dns:"pairs"`:
+					o("rr.%s, off, err = unpackDataSVCB(msg, off)\n")
 				case `dns:"domain-name"`:
 					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
 				case `dns:"apl"`:

+ 73 - 82
vendor/github.com/miekg/dns/msg_helpers.go

@@ -6,6 +6,7 @@ import (
 	"encoding/binary"
 	"encoding/hex"
 	"net"
+	"sort"
 	"strings"
 )
 
@@ -423,86 +424,12 @@ Option:
 	if off+int(optlen) > len(msg) {
 		return nil, len(msg), &Error{err: "overflow unpacking opt"}
 	}
-	switch code {
-	case EDNS0NSID:
-		e := new(EDNS0_NSID)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0SUBNET:
-		e := new(EDNS0_SUBNET)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0COOKIE:
-		e := new(EDNS0_COOKIE)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0EXPIRE:
-		e := new(EDNS0_EXPIRE)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0UL:
-		e := new(EDNS0_UL)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0LLQ:
-		e := new(EDNS0_LLQ)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0DAU:
-		e := new(EDNS0_DAU)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0DHU:
-		e := new(EDNS0_DHU)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0N3U:
-		e := new(EDNS0_N3U)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	case EDNS0PADDING:
-		e := new(EDNS0_PADDING)
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
-	default:
-		e := new(EDNS0_LOCAL)
-		e.Code = code
-		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
-			return nil, len(msg), err
-		}
-		edns = append(edns, e)
-		off += int(optlen)
+	e := makeDataOpt(code)
+	if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
+		return nil, len(msg), err
 	}
+	edns = append(edns, e)
+	off += int(optlen)
 
 	if off < len(msg) {
 		goto Option
@@ -521,9 +448,7 @@ func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
 		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
 		off += 4
 		if off+len(b) > len(msg) {
-			copy(msg[off:], b)
-			off = len(msg)
-			continue
+			return len(msg), &Error{err: "overflow packing opt"}
 		}
 		// Actual data
 		copy(msg[off:off+len(b)], b)
@@ -659,6 +584,65 @@ func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
 	return off, nil
 }
 
+func unpackDataSVCB(msg []byte, off int) ([]SVCBKeyValue, int, error) {
+	var xs []SVCBKeyValue
+	var code uint16
+	var length uint16
+	var err error
+	for off < len(msg) {
+		code, off, err = unpackUint16(msg, off)
+		if err != nil {
+			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
+		}
+		length, off, err = unpackUint16(msg, off)
+		if err != nil || off+int(length) > len(msg) {
+			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
+		}
+		e := makeSVCBKeyValue(SVCBKey(code))
+		if e == nil {
+			return nil, len(msg), &Error{err: "bad SVCB key"}
+		}
+		if err := e.unpack(msg[off : off+int(length)]); err != nil {
+			return nil, len(msg), err
+		}
+		if len(xs) > 0 && e.Key() <= xs[len(xs)-1].Key() {
+			return nil, len(msg), &Error{err: "SVCB keys not in strictly increasing order"}
+		}
+		xs = append(xs, e)
+		off += int(length)
+	}
+	return xs, off, nil
+}
+
+func packDataSVCB(pairs []SVCBKeyValue, msg []byte, off int) (int, error) {
+	pairs = append([]SVCBKeyValue(nil), pairs...)
+	sort.Slice(pairs, func(i, j int) bool {
+		return pairs[i].Key() < pairs[j].Key()
+	})
+	prev := svcb_RESERVED
+	for _, el := range pairs {
+		if el.Key() == prev {
+			return len(msg), &Error{err: "repeated SVCB keys are not allowed"}
+		}
+		prev = el.Key()
+		packed, err := el.pack()
+		if err != nil {
+			return len(msg), err
+		}
+		off, err = packUint16(uint16(el.Key()), msg, off)
+		if err != nil {
+			return len(msg), &Error{err: "overflow packing SVCB"}
+		}
+		off, err = packUint16(uint16(len(packed)), msg, off)
+		if err != nil || off+len(packed) > len(msg) {
+			return len(msg), &Error{err: "overflow packing SVCB"}
+		}
+		copy(msg[off:off+len(packed)], packed)
+		off += len(packed)
+	}
+	return off, nil
+}
+
 func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
 	var (
 		servers []string
@@ -730,6 +714,13 @@ func packDataAplPrefix(p *APLPrefix, msg []byte, off int) (int, error) {
 	if p.Negation {
 		n = 0x80
 	}
+
+	// trim trailing zero bytes as specified in RFC3123 Sections 4.1 and 4.2.
+	i := len(addr) - 1
+	for ; i >= 0 && addr[i] == 0; i-- {
+	}
+	addr = addr[:i+1]
+
 	adflen := uint8(len(addr)) & 0x7f
 	off, err = packUint8(n|adflen, msg, off)
 	if err != nil {

+ 11 - 5
vendor/github.com/miekg/dns/msg_truncate.go

@@ -8,8 +8,14 @@ package dns
 // record adding as many records as possible without exceeding the
 // requested buffer size.
 //
+// If the message fits within the requested size without compression,
+// Truncate will set the message's Compress attribute to false. It is
+// the caller's responsibility to set it back to true if they wish to
+// compress the payload regardless of size.
+//
 // The TC bit will be set if any records were excluded from the message.
-// This indicates to that the client should retry over TCP.
+// If the TC bit is already set on the message it will be retained.
+// TC indicates that the client should retry over TCP.
 //
 // According to RFC 2181, the TC bit should only be set if not all of the
 // "required" RRs can be included in the response. Unfortunately, we have
@@ -28,11 +34,11 @@ func (dns *Msg) Truncate(size int) {
 	}
 
 	// RFC 6891 mandates that the payload size in an OPT record
-	// less than 512 bytes must be treated as equal to 512 bytes.
+	// less than 512 (MinMsgSize) bytes must be treated as equal to 512 bytes.
 	//
 	// For ease of use, we impose that restriction here.
-	if size < 512 {
-		size = 512
+	if size < MinMsgSize {
+		size = MinMsgSize
 	}
 
 	l := msgLenWithCompressionMap(dns, nil) // uncompressed length
@@ -77,7 +83,7 @@ func (dns *Msg) Truncate(size int) {
 	}
 
 	// See the function documentation for when we set this.
-	dns.Truncated = len(dns.Answer) > numAnswer ||
+	dns.Truncated = dns.Truncated || len(dns.Answer) > numAnswer ||
 		len(dns.Ns) > numNS || len(dns.Extra) > numExtra
 
 	dns.Answer = dns.Answer[:numAnswer]

+ 2 - 2
vendor/github.com/miekg/dns/privaterr.go

@@ -6,7 +6,7 @@ import "strings"
 // RFC 6895. This allows one to experiment with new RR types, without requesting an
 // official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
 type PrivateRdata interface {
-	// String returns the text presentaton of the Rdata of the Private RR.
+	// String returns the text presentation of the Rdata of the Private RR.
 	String() string
 	// Parse parses the Rdata of the private RR.
 	Parse([]string) error
@@ -90,7 +90,7 @@ Fetch:
 	return nil
 }
 
-func (r1 *PrivateRR) isDuplicate(r2 RR) bool { return false }
+func (r *PrivateRR) isDuplicate(r2 RR) bool { return false }
 
 // PrivateHandle registers a private resource record type. It requires
 // string and numeric representation of private RR type and generator function as argument.

+ 59 - 22
vendor/github.com/miekg/dns/scan.go

@@ -150,6 +150,9 @@ func ReadRR(r io.Reader, file string) (RR, error) {
 // The text "; this is comment" is returned from Comment. Comments inside
 // the RR are returned concatenated along with the RR. Comments on a line
 // by themselves are discarded.
+//
+// Callers should not assume all returned data in an Resource Record is
+// syntactically correct, e.g. illegal base64 in RRSIGs will be returned as-is.
 type ZoneParser struct {
 	c *zlexer
 
@@ -577,10 +580,23 @@ func (zp *ZoneParser) Next() (RR, bool) {
 
 			st = zExpectRdata
 		case zExpectRdata:
-			var rr RR
-			if newFn, ok := TypeToRR[h.Rrtype]; ok && canParseAsRR(h.Rrtype) {
+			var (
+				rr             RR
+				parseAsRFC3597 bool
+			)
+			if newFn, ok := TypeToRR[h.Rrtype]; ok {
 				rr = newFn()
 				*rr.Header() = *h
+
+				// We may be parsing a known RR type using the RFC3597 format.
+				// If so, we handle that here in a generic way.
+				//
+				// This is also true for PrivateRR types which will have the
+				// RFC3597 parsing done for them and the Unpack method called
+				// to populate the RR instead of simply deferring to Parse.
+				if zp.c.Peek().token == "\\#" {
+					parseAsRFC3597 = true
+				}
 			} else {
 				rr = &RFC3597{Hdr: *h}
 			}
@@ -600,13 +616,18 @@ func (zp *ZoneParser) Next() (RR, bool) {
 				return zp.setParseError("unexpected newline", l)
 			}
 
-			if err := rr.parse(zp.c, zp.origin); err != nil {
+			parseAsRR := rr
+			if parseAsRFC3597 {
+				parseAsRR = &RFC3597{Hdr: *h}
+			}
+
+			if err := parseAsRR.parse(zp.c, zp.origin); err != nil {
 				// err is a concrete *ParseError without the file field set.
 				// The setParseError call below will construct a new
 				// *ParseError with file set to zp.file.
 
-				// If err.lex is nil than we have encounter an unknown RR type
-				// in that case we substitute our current lex token.
+				// err.lex may be nil in which case we substitute our current
+				// lex token.
 				if err.lex == (lex{}) {
 					return zp.setParseError(err.err, l)
 				}
@@ -614,6 +635,13 @@ func (zp *ZoneParser) Next() (RR, bool) {
 				return zp.setParseError(err.err, err.lex)
 			}
 
+			if parseAsRFC3597 {
+				err := parseAsRR.(*RFC3597).fromRFC3597(rr)
+				if err != nil {
+					return zp.setParseError(err.Error(), l)
+				}
+			}
+
 			return rr, true
 		}
 	}
@@ -623,18 +651,6 @@ func (zp *ZoneParser) Next() (RR, bool) {
 	return nil, false
 }
 
-// canParseAsRR returns true if the record type can be parsed as a
-// concrete RR. It blacklists certain record types that must be parsed
-// according to RFC 3597 because they lack a presentation format.
-func canParseAsRR(rrtype uint16) bool {
-	switch rrtype {
-	case TypeANY, TypeNULL, TypeOPT, TypeTSIG:
-		return false
-	default:
-		return true
-	}
-}
-
 type zlexer struct {
 	br io.ByteReader
 
@@ -1210,11 +1226,29 @@ func stringToCm(token string) (e, m uint8, ok bool) {
 		if cmeters, err = strconv.Atoi(s[1]); err != nil {
 			return
 		}
+		// There's no point in having more than 2 digits in this part, and would rather make the implementation complicated ('123' should be treated as '12').
+		// So we simply reject it.
+		// We also make sure the first character is a digit to reject '+-' signs.
+		if len(s[1]) > 2 || s[1][0] < '0' || s[1][0] > '9' {
+			return
+		}
+		if len(s[1]) == 1 {
+			// 'nn.1' must be treated as 'nn-meters and 10cm, not 1cm.
+			cmeters *= 10
+		}
+		if s[0] == "" {
+			// This will allow omitting the 'meter' part, like .01 (meaning 0.01m = 1cm).
+			break
+		}
 		fallthrough
 	case 1:
 		if meters, err = strconv.Atoi(s[0]); err != nil {
 			return
 		}
+		// RFC1876 states the max value is 90000000.00.  The latter two conditions enforce it.
+		if s[0][0] < '0' || s[0][0] > '9' || meters > 90000000 || (meters == 90000000 && cmeters != 0) {
+			return
+		}
 	case 0:
 		// huh?
 		return 0, 0, false
@@ -1227,13 +1261,10 @@ func stringToCm(token string) (e, m uint8, ok bool) {
 		e = 0
 		val = cmeters
 	}
-	for val > 10 {
+	for val >= 10 {
 		e++
 		val /= 10
 	}
-	if e > 9 {
-		ok = false
-	}
 	m = uint8(val)
 	return
 }
@@ -1275,6 +1306,9 @@ func appendOrigin(name, origin string) string {
 
 // LOC record helper function
 func locCheckNorth(token string, latitude uint32) (uint32, bool) {
+	if latitude > 90*1000*60*60 {
+		return latitude, false
+	}
 	switch token {
 	case "n", "N":
 		return LOC_EQUATOR + latitude, true
@@ -1286,6 +1320,9 @@ func locCheckNorth(token string, latitude uint32) (uint32, bool) {
 
 // LOC record helper function
 func locCheckEast(token string, longitude uint32) (uint32, bool) {
+	if longitude > 180*1000*60*60 {
+		return longitude, false
+	}
 	switch token {
 	case "e", "E":
 		return LOC_EQUATOR + longitude, true
@@ -1318,7 +1355,7 @@ func stringToNodeID(l lex) (uint64, *ParseError) {
 	if len(l.token) < 19 {
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 	}
-	// There must be three colons at fixes postitions, if not its a parse error
+	// There must be three colons at fixes positions, if not its a parse error
 	if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
 		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
 	}

+ 53 - 18
vendor/github.com/miekg/dns/scan_rr.go

@@ -590,7 +590,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 	// North
 	l, _ := c.Next()
 	i, e := strconv.ParseUint(l.token, 10, 32)
-	if e != nil || l.err {
+	if e != nil || l.err || i > 90 {
 		return &ParseError{"", "bad LOC Latitude", l}
 	}
 	rr.Latitude = 1000 * 60 * 60 * uint32(i)
@@ -601,7 +601,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 	if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok {
 		goto East
 	}
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
 		return &ParseError{"", "bad LOC Latitude minutes", l}
 	} else {
 		rr.Latitude += 1000 * 60 * uint32(i)
@@ -609,7 +609,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError {
 
 	c.Next() // zBlank
 	l, _ = c.Next()
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil || l.err {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
 		return &ParseError{"", "bad LOC Latitude seconds", l}
 	} else {
 		rr.Latitude += uint32(1000 * i)
@@ -627,7 +627,7 @@ East:
 	// East
 	c.Next() // zBlank
 	l, _ = c.Next()
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 180 {
 		return &ParseError{"", "bad LOC Longitude", l}
 	} else {
 		rr.Longitude = 1000 * 60 * 60 * uint32(i)
@@ -638,14 +638,14 @@ East:
 	if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok {
 		goto Altitude
 	}
-	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err {
+	if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
 		return &ParseError{"", "bad LOC Longitude minutes", l}
 	} else {
 		rr.Longitude += 1000 * 60 * uint32(i)
 	}
 	c.Next() // zBlank
 	l, _ = c.Next()
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil || l.err {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
 		return &ParseError{"", "bad LOC Longitude seconds", l}
 	} else {
 		rr.Longitude += uint32(1000 * i)
@@ -662,13 +662,13 @@ East:
 Altitude:
 	c.Next() // zBlank
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad LOC Altitude", l}
 	}
 	if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' {
 		l.token = l.token[0 : len(l.token)-1]
 	}
-	if i, err := strconv.ParseFloat(l.token, 32); err != nil {
+	if i, err := strconv.ParseFloat(l.token, 64); err != nil {
 		return &ParseError{"", "bad LOC Altitude", l}
 	} else {
 		rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5)
@@ -722,7 +722,7 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError {
 
 	c.Next()        // zBlank
 	l, _ = c.Next() // zString
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad HIP Hit", l}
 	}
 	rr.Hit = l.token // This can not contain spaces, see RFC 5205 Section 6.
@@ -730,11 +730,15 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError {
 
 	c.Next()        // zBlank
 	l, _ = c.Next() // zString
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad HIP PublicKey", l}
 	}
 	rr.PublicKey = l.token // This cannot contain spaces
-	rr.PublicKeyLength = uint16(base64.StdEncoding.DecodedLen(len(rr.PublicKey)))
+	decodedPK, decodedPKerr := base64.StdEncoding.DecodeString(rr.PublicKey)
+	if decodedPKerr != nil {
+		return &ParseError{"", "bad HIP PublicKey", l}
+	}
+	rr.PublicKeyLength = uint16(len(decodedPK))
 
 	// RendezvousServers (if any)
 	l, _ = c.Next()
@@ -846,6 +850,38 @@ func (rr *CSYNC) parse(c *zlexer, o string) *ParseError {
 	return nil
 }
 
+func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError {
+	l, _ := c.Next()
+	i, e := strconv.ParseUint(l.token, 10, 32)
+	if e != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Serial", l}
+	}
+	rr.Serial = uint32(i)
+
+	c.Next() // zBlank
+	l, _ = c.Next()
+	i, e1 := strconv.ParseUint(l.token, 10, 8)
+	if e1 != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Scheme", l}
+	}
+	rr.Scheme = uint8(i)
+
+	c.Next() // zBlank
+	l, _ = c.Next()
+	i, err := strconv.ParseUint(l.token, 10, 8)
+	if err != nil || l.err {
+		return &ParseError{"", "bad ZONEMD Hash Algorithm", l}
+	}
+	rr.Hash = uint8(i)
+
+	s, e2 := endingToString(c, "bad ZONEMD Digest")
+	if e2 != nil {
+		return e2
+	}
+	rr.Digest = s
+	return nil
+}
+
 func (rr *SIG) parse(c *zlexer, o string) *ParseError { return rr.RRSIG.parse(c, o) }
 
 func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
@@ -893,8 +929,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
 	l, _ = c.Next()
 	if i, err := StringToTime(l.token); err != nil {
 		// Try to see if all numeric and use it as epoch
-		if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
-			// TODO(miek): error out on > MAX_UINT32, same below
+		if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
 			rr.Expiration = uint32(i)
 		} else {
 			return &ParseError{"", "bad RRSIG Expiration", l}
@@ -906,7 +941,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
 	c.Next() // zBlank
 	l, _ = c.Next()
 	if i, err := StringToTime(l.token); err != nil {
-		if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
+		if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
 			rr.Inception = uint32(i)
 		} else {
 			return &ParseError{"", "bad RRSIG Inception", l}
@@ -998,7 +1033,7 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError {
 	rr.Iterations = uint16(i)
 	c.Next()
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad NSEC3 Salt", l}
 	}
 	if l.token != "-" {
@@ -1008,7 +1043,7 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError {
 
 	c.Next()
 	l, _ = c.Next()
-	if len(l.token) == 0 || l.err {
+	if l.token == "" || l.err {
 		return &ParseError{"", "bad NSEC3 NextDomain", l}
 	}
 	rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits)
@@ -1388,7 +1423,7 @@ func (rr *RFC3597) parse(c *zlexer, o string) *ParseError {
 
 	c.Next() // zBlank
 	l, _ = c.Next()
-	rdlength, e := strconv.Atoi(l.token)
+	rdlength, e := strconv.ParseUint(l.token, 10, 16)
 	if e != nil || l.err {
 		return &ParseError{"", "bad RFC3597 Rdata ", l}
 	}
@@ -1397,7 +1432,7 @@ func (rr *RFC3597) parse(c *zlexer, o string) *ParseError {
 	if e1 != nil {
 		return e1
 	}
-	if rdlength*2 != len(s) {
+	if int(rdlength)*2 != len(s) {
 		return &ParseError{"", "bad RFC3597 Rdata", l}
 	}
 	rr.Rdata = s

+ 2 - 2
vendor/github.com/miekg/dns/serve_mux.go

@@ -91,7 +91,7 @@ func (mux *ServeMux) HandleRemove(pattern string) {
 // are redirected to the parent zone (if that is also registered),
 // otherwise the child gets the query.
 //
-// If no handler is found, or there is no question, a standard SERVFAIL
+// If no handler is found, or there is no question, a standard REFUSED
 // message is returned
 func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
 	var h Handler
@@ -102,7 +102,7 @@ func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
 	if h != nil {
 		h.ServeDNS(w, req)
 	} else {
-		HandleFailed(w, req)
+		handleRefused(w, req)
 	}
 }
 

+ 92 - 28
vendor/github.com/miekg/dns/server.go

@@ -72,13 +72,22 @@ type response struct {
 	tsigStatus     error
 	tsigRequestMAC string
 	tsigSecret     map[string]string // the tsig secrets
-	udp            *net.UDPConn      // i/o connection if UDP was used
+	udp            net.PacketConn    // i/o connection if UDP was used
 	tcp            net.Conn          // i/o connection if TCP was used
 	udpSession     *SessionUDP       // oob data to get egress interface right
+	pcSession      net.Addr          // address to use when writing to a generic net.PacketConn
 	writer         Writer            // writer to output the raw DNS bits
 }
 
+// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
+func handleRefused(w ResponseWriter, r *Msg) {
+	m := new(Msg)
+	m.SetRcode(r, RcodeRefused)
+	w.WriteMsg(m)
+}
+
 // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
+// Deprecated: This function is going away.
 func HandleFailed(w ResponseWriter, r *Msg) {
 	m := new(Msg)
 	m.SetRcode(r, RcodeServerFailure)
@@ -139,12 +148,24 @@ type Reader interface {
 	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
 }
 
-// defaultReader is an adapter for the Server struct that implements the Reader interface
-// using the readTCP and readUDP func of the embedded Server.
+// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
+type PacketConnReader interface {
+	Reader
+
+	// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
+	// alter connection properties, for example the read-deadline.
+	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
+}
+
+// defaultReader is an adapter for the Server struct that implements the Reader and
+// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
+// of the embedded Server.
 type defaultReader struct {
 	*Server
 }
 
+var _ PacketConnReader = defaultReader{}
+
 func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
 	return dr.readTCP(conn, timeout)
 }
@@ -153,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
 	return dr.readUDP(conn, timeout)
 }
 
+func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+	return dr.readPacketConn(conn, timeout)
+}
+
 // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
 // Implementations should never return a nil Reader.
+// Readers should also implement the optional PacketConnReader interface.
+// PacketConnReader is required to use a generic net.PacketConn.
 type DecorateReader func(Reader) Reader
 
 // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
@@ -294,6 +321,7 @@ func (srv *Server) ListenAndServe() error {
 		}
 		u := l.(*net.UDPConn)
 		if e := setUDPSocketOptions(u); e != nil {
+			u.Close()
 			return e
 		}
 		srv.PacketConn = l
@@ -317,24 +345,22 @@ func (srv *Server) ActivateAndServe() error {
 
 	srv.init()
 
-	pConn := srv.PacketConn
-	l := srv.Listener
-	if pConn != nil {
+	if srv.PacketConn != nil {
 		// Check PacketConn interface's type is valid and value
 		// is not nil
-		if t, ok := pConn.(*net.UDPConn); ok && t != nil {
+		if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
 			if e := setUDPSocketOptions(t); e != nil {
 				return e
 			}
-			srv.started = true
-			unlock()
-			return srv.serveUDP(t)
 		}
+		srv.started = true
+		unlock()
+		return srv.serveUDP(srv.PacketConn)
 	}
-	if l != nil {
+	if srv.Listener != nil {
 		srv.started = true
 		unlock()
-		return srv.serveTCP(l)
+		return srv.serveTCP(srv.Listener)
 	}
 	return &Error{err: "bad listeners"}
 }
@@ -438,18 +464,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
 }
 
 // serveUDP starts a UDP listener for the server.
-func (srv *Server) serveUDP(l *net.UDPConn) error {
+func (srv *Server) serveUDP(l net.PacketConn) error {
 	defer l.Close()
 
-	if srv.NotifyStartedFunc != nil {
-		srv.NotifyStartedFunc()
-	}
-
 	reader := Reader(defaultReader{srv})
 	if srv.DecorateReader != nil {
 		reader = srv.DecorateReader(reader)
 	}
 
+	lUDP, isUDP := l.(*net.UDPConn)
+	readerPC, canPacketConn := reader.(PacketConnReader)
+	if !isUDP && !canPacketConn {
+		return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
+	}
+
+	if srv.NotifyStartedFunc != nil {
+		srv.NotifyStartedFunc()
+	}
+
 	var wg sync.WaitGroup
 	defer func() {
 		wg.Wait()
@@ -459,7 +491,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
 	rtimeout := srv.getReadTimeout()
 	// deadline is not used here
 	for srv.isStarted() {
-		m, s, err := reader.ReadUDP(l, rtimeout)
+		var (
+			m    []byte
+			sPC  net.Addr
+			sUDP *SessionUDP
+			err  error
+		)
+		if isUDP {
+			m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
+		} else {
+			m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
+		}
 		if err != nil {
 			if !srv.isStarted() {
 				return nil
@@ -476,7 +518,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
 			continue
 		}
 		wg.Add(1)
-		go srv.serveUDPPacket(&wg, m, l, s)
+		go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
 	}
 
 	return nil
@@ -538,8 +580,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
 }
 
 // Serve a new UDP request.
-func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
-	w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
+func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
+	w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
 	if srv.DecorateWriter != nil {
 		w.writer = srv.DecorateWriter(w)
 	} else {
@@ -651,6 +693,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
 	return m, s, nil
 }
 
+func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+	srv.lock.RLock()
+	if srv.started {
+		// See the comment in readTCP above.
+		conn.SetReadDeadline(time.Now().Add(timeout))
+	}
+	srv.lock.RUnlock()
+
+	m := srv.udpPool.Get().([]byte)
+	n, addr, err := conn.ReadFrom(m)
+	if err != nil {
+		srv.udpPool.Put(m)
+		return nil, nil, err
+	}
+	m = m[:n]
+	return m, addr, nil
+}
+
 // WriteMsg implements the ResponseWriter.WriteMsg method.
 func (w *response) WriteMsg(m *Msg) (err error) {
 	if w.closed {
@@ -684,17 +744,19 @@ func (w *response) Write(m []byte) (int, error) {
 
 	switch {
 	case w.udp != nil:
-		return WriteToSessionUDP(w.udp, m, w.udpSession)
+		if u, ok := w.udp.(*net.UDPConn); ok {
+			return WriteToSessionUDP(u, m, w.udpSession)
+		}
+		return w.udp.WriteTo(m, w.pcSession)
 	case w.tcp != nil:
 		if len(m) > MaxMsgSize {
 			return 0, &Error{err: "message too large"}
 		}
 
-		l := make([]byte, 2)
-		binary.BigEndian.PutUint16(l, uint16(len(m)))
-
-		n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
-		return int(n), err
+		msg := make([]byte, 2+len(m))
+		binary.BigEndian.PutUint16(msg, uint16(len(m)))
+		copy(msg[2:], m)
+		return w.tcp.Write(msg)
 	default:
 		panic("dns: internal error: udp and tcp both nil")
 	}
@@ -717,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
 	switch {
 	case w.udpSession != nil:
 		return w.udpSession.RemoteAddr()
+	case w.pcSession != nil:
+		return w.pcSession
 	case w.tcp != nil:
 		return w.tcp.RemoteAddr()
 	default:
-		panic("dns: internal error: udpSession and tcp both nil")
+		panic("dns: internal error: udpSession, pcSession and tcp are all nil")
 	}
 }
 

+ 3 - 15
vendor/github.com/miekg/dns/sig0.go

@@ -2,7 +2,6 @@ package dns
 
 import (
 	"crypto"
-	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/rsa"
 	"encoding/binary"
@@ -18,7 +17,7 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
 	if k == nil {
 		return nil, ErrPrivKey
 	}
-	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
+	if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
 		return nil, ErrKey
 	}
 
@@ -79,13 +78,13 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
 	if k == nil {
 		return ErrKey
 	}
-	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
+	if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
 		return ErrKey
 	}
 
 	var hash crypto.Hash
 	switch rr.Algorithm {
-	case DSA, RSASHA1:
+	case RSASHA1:
 		hash = crypto.SHA1
 	case RSASHA256, ECDSAP256SHA256:
 		hash = crypto.SHA256
@@ -178,17 +177,6 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
 	hashed := hasher.Sum(nil)
 	sig := buf[sigend:]
 	switch k.Algorithm {
-	case DSA:
-		pk := k.publicKeyDSA()
-		sig = sig[1:]
-		r := new(big.Int).SetBytes(sig[:len(sig)/2])
-		s := new(big.Int).SetBytes(sig[len(sig)/2:])
-		if pk != nil {
-			if dsa.Verify(pk, hashed, r, s) {
-				return nil
-			}
-			return ErrSig
-		}
 	case RSASHA1, RSASHA256, RSASHA512:
 		pk := k.publicKeyRSA()
 		if pk != nil {

+ 755 - 0
vendor/github.com/miekg/dns/svcb.go

@@ -0,0 +1,755 @@
+package dns
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"net"
+	"sort"
+	"strconv"
+	"strings"
+)
+
+// SVCBKey is the type of the keys used in the SVCB RR.
+type SVCBKey uint16
+
+// Keys defined in draft-ietf-dnsop-svcb-https-01 Section 12.3.2.
+const (
+	SVCB_MANDATORY       SVCBKey = 0
+	SVCB_ALPN            SVCBKey = 1
+	SVCB_NO_DEFAULT_ALPN SVCBKey = 2
+	SVCB_PORT            SVCBKey = 3
+	SVCB_IPV4HINT        SVCBKey = 4
+	SVCB_ECHCONFIG       SVCBKey = 5
+	SVCB_IPV6HINT        SVCBKey = 6
+	svcb_RESERVED        SVCBKey = 65535
+)
+
+var svcbKeyToStringMap = map[SVCBKey]string{
+	SVCB_MANDATORY:       "mandatory",
+	SVCB_ALPN:            "alpn",
+	SVCB_NO_DEFAULT_ALPN: "no-default-alpn",
+	SVCB_PORT:            "port",
+	SVCB_IPV4HINT:        "ipv4hint",
+	SVCB_ECHCONFIG:       "echconfig",
+	SVCB_IPV6HINT:        "ipv6hint",
+}
+
+var svcbStringToKeyMap = reverseSVCBKeyMap(svcbKeyToStringMap)
+
+func reverseSVCBKeyMap(m map[SVCBKey]string) map[string]SVCBKey {
+	n := make(map[string]SVCBKey, len(m))
+	for u, s := range m {
+		n[s] = u
+	}
+	return n
+}
+
+// String takes the numerical code of an SVCB key and returns its name.
+// Returns an empty string for reserved keys.
+// Accepts unassigned keys as well as experimental/private keys.
+func (key SVCBKey) String() string {
+	if x := svcbKeyToStringMap[key]; x != "" {
+		return x
+	}
+	if key == svcb_RESERVED {
+		return ""
+	}
+	return "key" + strconv.FormatUint(uint64(key), 10)
+}
+
+// svcbStringToKey returns the numerical code of an SVCB key.
+// Returns svcb_RESERVED for reserved/invalid keys.
+// Accepts unassigned keys as well as experimental/private keys.
+func svcbStringToKey(s string) SVCBKey {
+	if strings.HasPrefix(s, "key") {
+		a, err := strconv.ParseUint(s[3:], 10, 16)
+		// no leading zeros
+		// key shouldn't be registered
+		if err != nil || a == 65535 || s[3] == '0' || svcbKeyToStringMap[SVCBKey(a)] != "" {
+			return svcb_RESERVED
+		}
+		return SVCBKey(a)
+	}
+	if key, ok := svcbStringToKeyMap[s]; ok {
+		return key
+	}
+	return svcb_RESERVED
+}
+
+func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
+	l, _ := c.Next()
+	i, e := strconv.ParseUint(l.token, 10, 16)
+	if e != nil || l.err {
+		return &ParseError{l.token, "bad SVCB priority", l}
+	}
+	rr.Priority = uint16(i)
+
+	c.Next()        // zBlank
+	l, _ = c.Next() // zString
+	rr.Target = l.token
+
+	name, nameOk := toAbsoluteName(l.token, o)
+	if l.err || !nameOk {
+		return &ParseError{l.token, "bad SVCB Target", l}
+	}
+	rr.Target = name
+
+	// Values (if any)
+	l, _ = c.Next()
+	var xs []SVCBKeyValue
+	// Helps require whitespace between pairs.
+	// Prevents key1000="a"key1001=...
+	canHaveNextKey := true
+	for l.value != zNewline && l.value != zEOF {
+		switch l.value {
+		case zString:
+			if !canHaveNextKey {
+				// The key we can now read was probably meant to be
+				// a part of the last value.
+				return &ParseError{l.token, "bad SVCB value quotation", l}
+			}
+
+			// In key=value pairs, value does not have to be quoted unless value
+			// contains whitespace. And keys don't need to have values.
+			// Similarly, keys with an equality signs after them don't need values.
+			// l.token includes at least up to the first equality sign.
+			idx := strings.IndexByte(l.token, '=')
+			var key, value string
+			if idx < 0 {
+				// Key with no value and no equality sign
+				key = l.token
+			} else if idx == 0 {
+				return &ParseError{l.token, "bad SVCB key", l}
+			} else {
+				key, value = l.token[:idx], l.token[idx+1:]
+
+				if value == "" {
+					// We have a key and an equality sign. Maybe we have nothing
+					// after "=" or we have a double quote.
+					l, _ = c.Next()
+					if l.value == zQuote {
+						// Only needed when value ends with double quotes.
+						// Any value starting with zQuote ends with it.
+						canHaveNextKey = false
+
+						l, _ = c.Next()
+						switch l.value {
+						case zString:
+							// We have a value in double quotes.
+							value = l.token
+							l, _ = c.Next()
+							if l.value != zQuote {
+								return &ParseError{l.token, "SVCB unterminated value", l}
+							}
+						case zQuote:
+							// There's nothing in double quotes.
+						default:
+							return &ParseError{l.token, "bad SVCB value", l}
+						}
+					}
+				}
+			}
+			kv := makeSVCBKeyValue(svcbStringToKey(key))
+			if kv == nil {
+				return &ParseError{l.token, "bad SVCB key", l}
+			}
+			if err := kv.parse(value); err != nil {
+				return &ParseError{l.token, err.Error(), l}
+			}
+			xs = append(xs, kv)
+		case zQuote:
+			return &ParseError{l.token, "SVCB key can't contain double quotes", l}
+		case zBlank:
+			canHaveNextKey = true
+		default:
+			return &ParseError{l.token, "bad SVCB values", l}
+		}
+		l, _ = c.Next()
+	}
+	rr.Value = xs
+	if rr.Priority == 0 && len(xs) > 0 {
+		return &ParseError{l.token, "SVCB aliasform can't have values", l}
+	}
+	return nil
+}
+
+// makeSVCBKeyValue returns an SVCBKeyValue struct with the key or nil for reserved keys.
+func makeSVCBKeyValue(key SVCBKey) SVCBKeyValue {
+	switch key {
+	case SVCB_MANDATORY:
+		return new(SVCBMandatory)
+	case SVCB_ALPN:
+		return new(SVCBAlpn)
+	case SVCB_NO_DEFAULT_ALPN:
+		return new(SVCBNoDefaultAlpn)
+	case SVCB_PORT:
+		return new(SVCBPort)
+	case SVCB_IPV4HINT:
+		return new(SVCBIPv4Hint)
+	case SVCB_ECHCONFIG:
+		return new(SVCBECHConfig)
+	case SVCB_IPV6HINT:
+		return new(SVCBIPv6Hint)
+	case svcb_RESERVED:
+		return nil
+	default:
+		e := new(SVCBLocal)
+		e.KeyCode = key
+		return e
+	}
+}
+
+// SVCB RR. See RFC xxxx (https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-01).
+type SVCB struct {
+	Hdr      RR_Header
+	Priority uint16
+	Target   string         `dns:"domain-name"`
+	Value    []SVCBKeyValue `dns:"pairs"` // Value must be empty if Priority is zero.
+}
+
+// HTTPS RR. Everything valid for SVCB applies to HTTPS as well.
+// Except that the HTTPS record is intended for use with the HTTP and HTTPS protocols.
+type HTTPS struct {
+	SVCB
+}
+
+func (rr *HTTPS) String() string {
+	return rr.SVCB.String()
+}
+
+func (rr *HTTPS) parse(c *zlexer, o string) *ParseError {
+	return rr.SVCB.parse(c, o)
+}
+
+// SVCBKeyValue defines a key=value pair for the SVCB RR type.
+// An SVCB RR can have multiple SVCBKeyValues appended to it.
+type SVCBKeyValue interface {
+	Key() SVCBKey          // Key returns the numerical key code.
+	pack() ([]byte, error) // pack returns the encoded value.
+	unpack([]byte) error   // unpack sets the value.
+	String() string        // String returns the string representation of the value.
+	parse(string) error    // parse sets the value to the given string representation of the value.
+	copy() SVCBKeyValue    // copy returns a deep-copy of the pair.
+	len() int              // len returns the length of value in the wire format.
+}
+
+// SVCBMandatory pair adds to required keys that must be interpreted for the RR
+// to be functional.
+// Basic use pattern for creating a mandatory option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBMandatory)
+//	e.Code = []uint16{65403}
+//	s.Value = append(s.Value, e)
+type SVCBMandatory struct {
+	Code []SVCBKey // Must not include mandatory
+}
+
+func (*SVCBMandatory) Key() SVCBKey { return SVCB_MANDATORY }
+
+func (s *SVCBMandatory) String() string {
+	str := make([]string, len(s.Code))
+	for i, e := range s.Code {
+		str[i] = e.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBMandatory) pack() ([]byte, error) {
+	codes := append([]SVCBKey(nil), s.Code...)
+	sort.Slice(codes, func(i, j int) bool {
+		return codes[i] < codes[j]
+	})
+	b := make([]byte, 2*len(codes))
+	for i, e := range codes {
+		binary.BigEndian.PutUint16(b[2*i:], uint16(e))
+	}
+	return b, nil
+}
+
+func (s *SVCBMandatory) unpack(b []byte) error {
+	if len(b)%2 != 0 {
+		return errors.New("dns: svcbmandatory: value length is not a multiple of 2")
+	}
+	codes := make([]SVCBKey, 0, len(b)/2)
+	for i := 0; i < len(b); i += 2 {
+		// We assume strictly increasing order.
+		codes = append(codes, SVCBKey(binary.BigEndian.Uint16(b[i:])))
+	}
+	s.Code = codes
+	return nil
+}
+
+func (s *SVCBMandatory) parse(b string) error {
+	str := strings.Split(b, ",")
+	codes := make([]SVCBKey, 0, len(str))
+	for _, e := range str {
+		codes = append(codes, svcbStringToKey(e))
+	}
+	s.Code = codes
+	return nil
+}
+
+func (s *SVCBMandatory) len() int {
+	return 2 * len(s.Code)
+}
+
+func (s *SVCBMandatory) copy() SVCBKeyValue {
+	return &SVCBMandatory{
+		append([]SVCBKey(nil), s.Code...),
+	}
+}
+
+// SVCBAlpn pair is used to list supported connection protocols.
+// Protocol ids can be found at:
+// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
+// Basic use pattern for creating an alpn option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBAlpn)
+//	e.Alpn = []string{"h2", "http/1.1"}
+//	h.Value = append(o.Value, e)
+type SVCBAlpn struct {
+	Alpn []string
+}
+
+func (*SVCBAlpn) Key() SVCBKey     { return SVCB_ALPN }
+func (s *SVCBAlpn) String() string { return strings.Join(s.Alpn, ",") }
+
+func (s *SVCBAlpn) pack() ([]byte, error) {
+	// Liberally estimate the size of an alpn as 10 octets
+	b := make([]byte, 0, 10*len(s.Alpn))
+	for _, e := range s.Alpn {
+		if e == "" {
+			return nil, errors.New("dns: svcbalpn: empty alpn-id")
+		}
+		if len(e) > 255 {
+			return nil, errors.New("dns: svcbalpn: alpn-id too long")
+		}
+		b = append(b, byte(len(e)))
+		b = append(b, e...)
+	}
+	return b, nil
+}
+
+func (s *SVCBAlpn) unpack(b []byte) error {
+	// Estimate the size of the smallest alpn as 4 bytes
+	alpn := make([]string, 0, len(b)/4)
+	for i := 0; i < len(b); {
+		length := int(b[i])
+		i++
+		if i+length > len(b) {
+			return errors.New("dns: svcbalpn: alpn array overflowing")
+		}
+		alpn = append(alpn, string(b[i:i+length]))
+		i += length
+	}
+	s.Alpn = alpn
+	return nil
+}
+
+func (s *SVCBAlpn) parse(b string) error {
+	s.Alpn = strings.Split(b, ",")
+	return nil
+}
+
+func (s *SVCBAlpn) len() int {
+	var l int
+	for _, e := range s.Alpn {
+		l += 1 + len(e)
+	}
+	return l
+}
+
+func (s *SVCBAlpn) copy() SVCBKeyValue {
+	return &SVCBAlpn{
+		append([]string(nil), s.Alpn...),
+	}
+}
+
+// SVCBNoDefaultAlpn pair signifies no support for default connection protocols.
+// Basic use pattern for creating a no-default-alpn option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBNoDefaultAlpn)
+//	s.Value = append(s.Value, e)
+type SVCBNoDefaultAlpn struct{}
+
+func (*SVCBNoDefaultAlpn) Key() SVCBKey          { return SVCB_NO_DEFAULT_ALPN }
+func (*SVCBNoDefaultAlpn) copy() SVCBKeyValue    { return &SVCBNoDefaultAlpn{} }
+func (*SVCBNoDefaultAlpn) pack() ([]byte, error) { return []byte{}, nil }
+func (*SVCBNoDefaultAlpn) String() string        { return "" }
+func (*SVCBNoDefaultAlpn) len() int              { return 0 }
+
+func (*SVCBNoDefaultAlpn) unpack(b []byte) error {
+	if len(b) != 0 {
+		return errors.New("dns: svcbnodefaultalpn: no_default_alpn must have no value")
+	}
+	return nil
+}
+
+func (*SVCBNoDefaultAlpn) parse(b string) error {
+	if b != "" {
+		return errors.New("dns: svcbnodefaultalpn: no_default_alpn must have no value")
+	}
+	return nil
+}
+
+// SVCBPort pair defines the port for connection.
+// Basic use pattern for creating a port option:
+//
+//	s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
+//	e := new(dns.SVCBPort)
+//	e.Port = 80
+//	s.Value = append(s.Value, e)
+type SVCBPort struct {
+	Port uint16
+}
+
+func (*SVCBPort) Key() SVCBKey         { return SVCB_PORT }
+func (*SVCBPort) len() int             { return 2 }
+func (s *SVCBPort) String() string     { return strconv.FormatUint(uint64(s.Port), 10) }
+func (s *SVCBPort) copy() SVCBKeyValue { return &SVCBPort{s.Port} }
+
+func (s *SVCBPort) unpack(b []byte) error {
+	if len(b) != 2 {
+		return errors.New("dns: svcbport: port length is not exactly 2 octets")
+	}
+	s.Port = binary.BigEndian.Uint16(b)
+	return nil
+}
+
+func (s *SVCBPort) pack() ([]byte, error) {
+	b := make([]byte, 2)
+	binary.BigEndian.PutUint16(b, s.Port)
+	return b, nil
+}
+
+func (s *SVCBPort) parse(b string) error {
+	port, err := strconv.ParseUint(b, 10, 16)
+	if err != nil {
+		return errors.New("dns: svcbport: port out of range")
+	}
+	s.Port = uint16(port)
+	return nil
+}
+
+// SVCBIPv4Hint pair suggests an IPv4 address which may be used to open connections
+// if A and AAAA record responses for SVCB's Target domain haven't been received.
+// In that case, optionally, A and AAAA requests can be made, after which the connection
+// to the hinted IP address may be terminated and a new connection may be opened.
+// Basic use pattern for creating an ipv4hint option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBIPv4Hint)
+//	e.Hint = []net.IP{net.IPv4(1,1,1,1).To4()}
+//
+//  Or
+//
+//	e.Hint = []net.IP{net.ParseIP("1.1.1.1").To4()}
+//	h.Value = append(h.Value, e)
+type SVCBIPv4Hint struct {
+	Hint []net.IP
+}
+
+func (*SVCBIPv4Hint) Key() SVCBKey { return SVCB_IPV4HINT }
+func (s *SVCBIPv4Hint) len() int   { return 4 * len(s.Hint) }
+
+func (s *SVCBIPv4Hint) pack() ([]byte, error) {
+	b := make([]byte, 0, 4*len(s.Hint))
+	for _, e := range s.Hint {
+		x := e.To4()
+		if x == nil {
+			return nil, errors.New("dns: svcbipv4hint: expected ipv4, hint is ipv6")
+		}
+		b = append(b, x...)
+	}
+	return b, nil
+}
+
+func (s *SVCBIPv4Hint) unpack(b []byte) error {
+	if len(b) == 0 || len(b)%4 != 0 {
+		return errors.New("dns: svcbipv4hint: ipv4 address byte array length is not a multiple of 4")
+	}
+	x := make([]net.IP, 0, len(b)/4)
+	for i := 0; i < len(b); i += 4 {
+		x = append(x, net.IP(b[i:i+4]))
+	}
+	s.Hint = x
+	return nil
+}
+
+func (s *SVCBIPv4Hint) String() string {
+	str := make([]string, len(s.Hint))
+	for i, e := range s.Hint {
+		x := e.To4()
+		if x == nil {
+			return "<nil>"
+		}
+		str[i] = x.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBIPv4Hint) parse(b string) error {
+	if strings.Contains(b, ":") {
+		return errors.New("dns: svcbipv4hint: expected ipv4, got ipv6")
+	}
+	str := strings.Split(b, ",")
+	dst := make([]net.IP, len(str))
+	for i, e := range str {
+		ip := net.ParseIP(e).To4()
+		if ip == nil {
+			return errors.New("dns: svcbipv4hint: bad ip")
+		}
+		dst[i] = ip
+	}
+	s.Hint = dst
+	return nil
+}
+
+func (s *SVCBIPv4Hint) copy() SVCBKeyValue {
+	hint := make([]net.IP, len(s.Hint))
+	for i, ip := range s.Hint {
+		hint[i] = copyIP(ip)
+	}
+
+	return &SVCBIPv4Hint{
+		Hint: hint,
+	}
+}
+
+// SVCBECHConfig pair contains the ECHConfig structure defined in draft-ietf-tls-esni [RFC xxxx].
+// Basic use pattern for creating an echconfig option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBECHConfig)
+//	e.ECH = []byte{0xfe, 0x08, ...}
+//	h.Value = append(h.Value, e)
+type SVCBECHConfig struct {
+	ECH []byte
+}
+
+func (*SVCBECHConfig) Key() SVCBKey     { return SVCB_ECHCONFIG }
+func (s *SVCBECHConfig) String() string { return toBase64(s.ECH) }
+func (s *SVCBECHConfig) len() int       { return len(s.ECH) }
+
+func (s *SVCBECHConfig) pack() ([]byte, error) {
+	return append([]byte(nil), s.ECH...), nil
+}
+
+func (s *SVCBECHConfig) copy() SVCBKeyValue {
+	return &SVCBECHConfig{
+		append([]byte(nil), s.ECH...),
+	}
+}
+
+func (s *SVCBECHConfig) unpack(b []byte) error {
+	s.ECH = append([]byte(nil), b...)
+	return nil
+}
+func (s *SVCBECHConfig) parse(b string) error {
+	x, err := fromBase64([]byte(b))
+	if err != nil {
+		return errors.New("dns: svcbechconfig: bad base64 echconfig")
+	}
+	s.ECH = x
+	return nil
+}
+
+// SVCBIPv6Hint pair suggests an IPv6 address which may be used to open connections
+// if A and AAAA record responses for SVCB's Target domain haven't been received.
+// In that case, optionally, A and AAAA requests can be made, after which the
+// connection to the hinted IP address may be terminated and a new connection may be opened.
+// Basic use pattern for creating an ipv6hint option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBIPv6Hint)
+//	e.Hint = []net.IP{net.ParseIP("2001:db8::1")}
+//	h.Value = append(h.Value, e)
+type SVCBIPv6Hint struct {
+	Hint []net.IP
+}
+
+func (*SVCBIPv6Hint) Key() SVCBKey { return SVCB_IPV6HINT }
+func (s *SVCBIPv6Hint) len() int   { return 16 * len(s.Hint) }
+
+func (s *SVCBIPv6Hint) pack() ([]byte, error) {
+	b := make([]byte, 0, 16*len(s.Hint))
+	for _, e := range s.Hint {
+		if len(e) != net.IPv6len || e.To4() != nil {
+			return nil, errors.New("dns: svcbipv6hint: expected ipv6, hint is ipv4")
+		}
+		b = append(b, e...)
+	}
+	return b, nil
+}
+
+func (s *SVCBIPv6Hint) unpack(b []byte) error {
+	if len(b) == 0 || len(b)%16 != 0 {
+		return errors.New("dns: svcbipv6hint: ipv6 address byte array length not a multiple of 16")
+	}
+	x := make([]net.IP, 0, len(b)/16)
+	for i := 0; i < len(b); i += 16 {
+		ip := net.IP(b[i : i+16])
+		if ip.To4() != nil {
+			return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4")
+		}
+		x = append(x, ip)
+	}
+	s.Hint = x
+	return nil
+}
+
+func (s *SVCBIPv6Hint) String() string {
+	str := make([]string, len(s.Hint))
+	for i, e := range s.Hint {
+		if x := e.To4(); x != nil {
+			return "<nil>"
+		}
+		str[i] = e.String()
+	}
+	return strings.Join(str, ",")
+}
+
+func (s *SVCBIPv6Hint) parse(b string) error {
+	if strings.Contains(b, ".") {
+		return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4")
+	}
+	str := strings.Split(b, ",")
+	dst := make([]net.IP, len(str))
+	for i, e := range str {
+		ip := net.ParseIP(e)
+		if ip == nil {
+			return errors.New("dns: svcbipv6hint: bad ip")
+		}
+		dst[i] = ip
+	}
+	s.Hint = dst
+	return nil
+}
+
+func (s *SVCBIPv6Hint) copy() SVCBKeyValue {
+	hint := make([]net.IP, len(s.Hint))
+	for i, ip := range s.Hint {
+		hint[i] = copyIP(ip)
+	}
+
+	return &SVCBIPv6Hint{
+		Hint: hint,
+	}
+}
+
+// SVCBLocal pair is intended for experimental/private use. The key is recommended
+// to be in the range [SVCB_PRIVATE_LOWER, SVCB_PRIVATE_UPPER].
+// Basic use pattern for creating a keyNNNNN option:
+//
+//	h := new(dns.HTTPS)
+//	h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
+//	e := new(dns.SVCBLocal)
+//	e.KeyCode = 65400
+//	e.Data = []byte("abc")
+//	h.Value = append(h.Value, e)
+type SVCBLocal struct {
+	KeyCode SVCBKey // Never 65535 or any assigned keys.
+	Data    []byte  // All byte sequences are allowed.
+}
+
+func (s *SVCBLocal) Key() SVCBKey          { return s.KeyCode }
+func (s *SVCBLocal) pack() ([]byte, error) { return append([]byte(nil), s.Data...), nil }
+func (s *SVCBLocal) len() int              { return len(s.Data) }
+
+func (s *SVCBLocal) unpack(b []byte) error {
+	s.Data = append([]byte(nil), b...)
+	return nil
+}
+
+func (s *SVCBLocal) String() string {
+	var str strings.Builder
+	str.Grow(4 * len(s.Data))
+	for _, e := range s.Data {
+		if ' ' <= e && e <= '~' {
+			switch e {
+			case '"', ';', ' ', '\\':
+				str.WriteByte('\\')
+				str.WriteByte(e)
+			default:
+				str.WriteByte(e)
+			}
+		} else {
+			str.WriteString(escapeByte(e))
+		}
+	}
+	return str.String()
+}
+
+func (s *SVCBLocal) parse(b string) error {
+	data := make([]byte, 0, len(b))
+	for i := 0; i < len(b); {
+		if b[i] != '\\' {
+			data = append(data, b[i])
+			i++
+			continue
+		}
+		if i+1 == len(b) {
+			return errors.New("dns: svcblocal: svcb private/experimental key escape unterminated")
+		}
+		if isDigit(b[i+1]) {
+			if i+3 < len(b) && isDigit(b[i+2]) && isDigit(b[i+3]) {
+				a, err := strconv.ParseUint(b[i+1:i+4], 10, 8)
+				if err == nil {
+					i += 4
+					data = append(data, byte(a))
+					continue
+				}
+			}
+			return errors.New("dns: svcblocal: svcb private/experimental key bad escaped octet")
+		} else {
+			data = append(data, b[i+1])
+			i += 2
+		}
+	}
+	s.Data = data
+	return nil
+}
+
+func (s *SVCBLocal) copy() SVCBKeyValue {
+	return &SVCBLocal{s.KeyCode,
+		append([]byte(nil), s.Data...),
+	}
+}
+
+func (rr *SVCB) String() string {
+	s := rr.Hdr.String() +
+		strconv.Itoa(int(rr.Priority)) + " " +
+		sprintName(rr.Target)
+	for _, e := range rr.Value {
+		s += " " + e.Key().String() + "=\"" + e.String() + "\""
+	}
+	return s
+}
+
+// areSVCBPairArraysEqual checks if SVCBKeyValue arrays are equal after sorting their
+// copies. arrA and arrB have equal lengths, otherwise zduplicate.go wouldn't call this function.
+func areSVCBPairArraysEqual(a []SVCBKeyValue, b []SVCBKeyValue) bool {
+	a = append([]SVCBKeyValue(nil), a...)
+	b = append([]SVCBKeyValue(nil), b...)
+	sort.Slice(a, func(i, j int) bool { return a[i].Key() < a[j].Key() })
+	sort.Slice(b, func(i, j int) bool { return b[i].Key() < b[j].Key() })
+	for i, e := range a {
+		if e.Key() != b[i].Key() {
+			return false
+		}
+		b1, err1 := e.pack()
+		b2, err2 := b[i].pack()
+		if err1 != nil || err2 != nil || !bytes.Equal(b1, b2) {
+			return false
+		}
+	}
+	return true
+}

+ 99 - 59
vendor/github.com/miekg/dns/tsig.go

@@ -2,7 +2,6 @@ package dns
 
 import (
 	"crypto/hmac"
-	"crypto/md5"
 	"crypto/sha1"
 	"crypto/sha256"
 	"crypto/sha512"
@@ -16,12 +15,65 @@ import (
 
 // HMAC hashing codes. These are transmitted as domain names.
 const (
-	HmacMD5    = "hmac-md5.sig-alg.reg.int."
 	HmacSHA1   = "hmac-sha1."
+	HmacSHA224 = "hmac-sha224."
 	HmacSHA256 = "hmac-sha256."
+	HmacSHA384 = "hmac-sha384."
 	HmacSHA512 = "hmac-sha512."
+
+	HmacMD5 = "hmac-md5.sig-alg.reg.int." // Deprecated: HmacMD5 is no longer supported.
 )
 
+// TsigProvider provides the API to plug-in a custom TSIG implementation.
+type TsigProvider interface {
+	// Generate is passed the DNS message to be signed and the partial TSIG RR. It returns the signature and nil, otherwise an error.
+	Generate(msg []byte, t *TSIG) ([]byte, error)
+	// Verify is passed the DNS message to be verified and the TSIG RR. If the signature is valid it will return nil, otherwise an error.
+	Verify(msg []byte, t *TSIG) error
+}
+
+type tsigHMACProvider string
+
+func (key tsigHMACProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
+	// If we barf here, the caller is to blame
+	rawsecret, err := fromBase64([]byte(key))
+	if err != nil {
+		return nil, err
+	}
+	var h hash.Hash
+	switch CanonicalName(t.Algorithm) {
+	case HmacSHA1:
+		h = hmac.New(sha1.New, rawsecret)
+	case HmacSHA224:
+		h = hmac.New(sha256.New224, rawsecret)
+	case HmacSHA256:
+		h = hmac.New(sha256.New, rawsecret)
+	case HmacSHA384:
+		h = hmac.New(sha512.New384, rawsecret)
+	case HmacSHA512:
+		h = hmac.New(sha512.New, rawsecret)
+	default:
+		return nil, ErrKeyAlg
+	}
+	h.Write(msg)
+	return h.Sum(nil), nil
+}
+
+func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
+	b, err := key.Generate(msg, t)
+	if err != nil {
+		return err
+	}
+	mac, err := hex.DecodeString(t.MAC)
+	if err != nil {
+		return err
+	}
+	if !hmac.Equal(b, mac) {
+		return ErrSig
+	}
+	return nil
+}
+
 // TSIG is the RR the holds the transaction signature of a message.
 // See RFC 2845 and RFC 4635.
 type TSIG struct {
@@ -54,8 +106,8 @@ func (rr *TSIG) String() string {
 	return s
 }
 
-func (rr *TSIG) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on TSIG")
+func (*TSIG) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "TSIG records do not have a presentation format"}
 }
 
 // The following values must be put in wireformat, so that the MAC can be calculated.
@@ -96,14 +148,13 @@ type timerWireFmt struct {
 // timersOnly is false.
 // If something goes wrong an error is returned, otherwise it is nil.
 func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) {
+	return tsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly)
+}
+
+func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) {
 	if m.IsTsig() == nil {
 		panic("dns: TSIG not last RR in additional")
 	}
-	// If we barf here, the caller is to blame
-	rawsecret, err := fromBase64([]byte(secret))
-	if err != nil {
-		return nil, "", err
-	}
 
 	rr := m.Extra[len(m.Extra)-1].(*TSIG)
 	m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
@@ -111,32 +162,21 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
 	if err != nil {
 		return nil, "", err
 	}
-	buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
+	buf, err := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
+	if err != nil {
+		return nil, "", err
+	}
 
 	t := new(TSIG)
-	var h hash.Hash
-	switch CanonicalName(rr.Algorithm) {
-	case HmacMD5:
-		h = hmac.New(md5.New, rawsecret)
-	case HmacSHA1:
-		h = hmac.New(sha1.New, rawsecret)
-	case HmacSHA256:
-		h = hmac.New(sha256.New, rawsecret)
-	case HmacSHA512:
-		h = hmac.New(sha512.New, rawsecret)
-	default:
-		return nil, "", ErrKeyAlg
+	// Copy all TSIG fields except MAC and its size, which are filled using the computed digest.
+	*t = *rr
+	mac, err := provider.Generate(buf, rr)
+	if err != nil {
+		return nil, "", err
 	}
-	h.Write(buf)
-	t.MAC = hex.EncodeToString(h.Sum(nil))
+	t.MAC = hex.EncodeToString(mac)
 	t.MACSize = uint16(len(t.MAC) / 2) // Size is half!
 
-	t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0}
-	t.Fudge = rr.Fudge
-	t.TimeSigned = rr.TimeSigned
-	t.Algorithm = rr.Algorithm
-	t.OrigId = m.Id
-
 	tbuf := make([]byte, Len(t))
 	off, err := PackRR(t, tbuf, 0, nil, false)
 	if err != nil {
@@ -153,26 +193,34 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
 // If the signature does not validate err contains the
 // error, otherwise it is nil.
 func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
-	rawsecret, err := fromBase64([]byte(secret))
-	if err != nil {
-		return err
-	}
+	return tsigVerify(msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix()))
+}
+
+func tsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error {
+	return tsigVerify(msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix()))
+}
+
+// actual implementation of TsigVerify, taking the current time ('now') as a parameter for the convenience of tests.
+func tsigVerify(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool, now uint64) error {
 	// Strip the TSIG from the incoming msg
 	stripped, tsig, err := stripTsig(msg)
 	if err != nil {
 		return err
 	}
 
-	msgMAC, err := hex.DecodeString(tsig.MAC)
+	buf, err := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
 	if err != nil {
 		return err
 	}
 
-	buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
+	if err := provider.Verify(buf, tsig); err != nil {
+		return err
+	}
 
 	// Fudge factor works both ways. A message can arrive before it was signed because
 	// of clock skew.
-	now := uint64(time.Now().Unix())
+	// We check this after verifying the signature, following draft-ietf-dnsop-rfc2845bis
+	// instead of RFC2845, in order to prevent a security vulnerability as reported in CVE-2017-3142/3143.
 	ti := now - tsig.TimeSigned
 	if now < tsig.TimeSigned {
 		ti = tsig.TimeSigned - now
@@ -181,28 +229,11 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
 		return ErrTime
 	}
 
-	var h hash.Hash
-	switch CanonicalName(tsig.Algorithm) {
-	case HmacMD5:
-		h = hmac.New(md5.New, rawsecret)
-	case HmacSHA1:
-		h = hmac.New(sha1.New, rawsecret)
-	case HmacSHA256:
-		h = hmac.New(sha256.New, rawsecret)
-	case HmacSHA512:
-		h = hmac.New(sha512.New, rawsecret)
-	default:
-		return ErrKeyAlg
-	}
-	h.Write(buf)
-	if !hmac.Equal(h.Sum(nil), msgMAC) {
-		return ErrSig
-	}
 	return nil
 }
 
 // Create a wiredata buffer for the MAC calculation.
-func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte {
+func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) ([]byte, error) {
 	var buf []byte
 	if rr.TimeSigned == 0 {
 		rr.TimeSigned = uint64(time.Now().Unix())
@@ -219,7 +250,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		m.MACSize = uint16(len(requestMAC) / 2)
 		m.MAC = requestMAC
 		buf = make([]byte, len(requestMAC)) // long enough
-		n, _ := packMacWire(m, buf)
+		n, err := packMacWire(m, buf)
+		if err != nil {
+			return nil, err
+		}
 		buf = buf[:n]
 	}
 
@@ -228,7 +262,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		tsig := new(timerWireFmt)
 		tsig.TimeSigned = rr.TimeSigned
 		tsig.Fudge = rr.Fudge
-		n, _ := packTimerWire(tsig, tsigvar)
+		n, err := packTimerWire(tsig, tsigvar)
+		if err != nil {
+			return nil, err
+		}
 		tsigvar = tsigvar[:n]
 	} else {
 		tsig := new(tsigWireFmt)
@@ -241,7 +278,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 		tsig.Error = rr.Error
 		tsig.OtherLen = rr.OtherLen
 		tsig.OtherData = rr.OtherData
-		n, _ := packTsigWire(tsig, tsigvar)
+		n, err := packTsigWire(tsig, tsigvar)
+		if err != nil {
+			return nil, err
+		}
 		tsigvar = tsigvar[:n]
 	}
 
@@ -251,7 +291,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
 	} else {
 		buf = append(msgbuf, tsigvar...)
 	}
-	return buf
+	return buf, nil
 }
 
 // Strip the TSIG from the raw message.

+ 79 - 50
vendor/github.com/miekg/dns/types.go

@@ -81,6 +81,9 @@ const (
 	TypeCDNSKEY    uint16 = 60
 	TypeOPENPGPKEY uint16 = 61
 	TypeCSYNC      uint16 = 62
+	TypeZONEMD     uint16 = 63
+	TypeSVCB       uint16 = 64
+	TypeHTTPS      uint16 = 65
 	TypeSPF        uint16 = 99
 	TypeUINFO      uint16 = 100
 	TypeUID        uint16 = 101
@@ -148,6 +151,14 @@ const (
 	OpcodeUpdate = 5
 )
 
+// Used in ZONEMD https://tools.ietf.org/html/rfc8976
+const (
+	ZoneMDSchemeSimple = 1
+
+	ZoneMDHashAlgSHA384 = 1
+	ZoneMDHashAlgSHA512 = 2
+)
+
 // Header is the wire format for the DNS packet header.
 type Header struct {
 	Id                                 uint16
@@ -243,8 +254,8 @@ type ANY struct {
 
 func (rr *ANY) String() string { return rr.Hdr.String() }
 
-func (rr *ANY) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on ANY")
+func (*ANY) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "ANY records do not have a presentation format"}
 }
 
 // NULL RR. See RFC 1035.
@@ -258,8 +269,8 @@ func (rr *NULL) String() string {
 	return ";" + rr.Hdr.String() + rr.Data
 }
 
-func (rr *NULL) parse(c *zlexer, origin string) *ParseError {
-	panic("dns: internal error: parse should never be called on NULL")
+func (*NULL) parse(c *zlexer, origin string) *ParseError {
+	return &ParseError{err: "NULL records do not have a presentation format"}
 }
 
 // CNAME RR. See RFC 1034.
@@ -445,45 +456,38 @@ func sprintName(s string) string {
 	var dst strings.Builder
 
 	for i := 0; i < len(s); {
-		if i+1 < len(s) && s[i] == '\\' && s[i+1] == '.' {
+		if s[i] == '.' {
 			if dst.Len() != 0 {
-				dst.WriteString(s[i : i+2])
+				dst.WriteByte('.')
 			}
-			i += 2
+			i++
 			continue
 		}
 
 		b, n := nextByte(s, i)
 		if n == 0 {
-			i++
-			continue
-		}
-		if b == '.' {
-			if dst.Len() != 0 {
-				dst.WriteByte('.')
+			// Drop "dangling" incomplete escapes.
+			if dst.Len() == 0 {
+				return s[:i]
 			}
-			i += n
-			continue
+			break
 		}
-		switch b {
-		case ' ', '\'', '@', ';', '(', ')', '"', '\\': // additional chars to escape
+		if isDomainNameLabelSpecial(b) {
 			if dst.Len() == 0 {
 				dst.Grow(len(s) * 2)
 				dst.WriteString(s[:i])
 			}
 			dst.WriteByte('\\')
 			dst.WriteByte(b)
-		default:
-			if ' ' <= b && b <= '~' {
-				if dst.Len() != 0 {
-					dst.WriteByte(b)
-				}
-			} else {
-				if dst.Len() == 0 {
-					dst.Grow(len(s) * 2)
-					dst.WriteString(s[:i])
-				}
-				dst.WriteString(escapeByte(b))
+		} else if b < ' ' || b > '~' { // unprintable, use \DDD
+			if dst.Len() == 0 {
+				dst.Grow(len(s) * 2)
+				dst.WriteString(s[:i])
+			}
+			dst.WriteString(escapeByte(b))
+		} else {
+			if dst.Len() != 0 {
+				dst.WriteByte(b)
 			}
 		}
 		i += n
@@ -506,15 +510,10 @@ func sprintTxtOctet(s string) string {
 		}
 
 		b, n := nextByte(s, i)
-		switch {
-		case n == 0:
+		if n == 0 {
 			i++ // dangling back slash
-		case b == '.':
-			dst.WriteByte('.')
-		case b < ' ' || b > '~':
-			dst.WriteString(escapeByte(b))
-		default:
-			dst.WriteByte(b)
+		} else {
+			writeTXTStringByte(&dst, b)
 		}
 		i += n
 	}
@@ -590,6 +589,17 @@ func escapeByte(b byte) string {
 	return escapedByteLarge[int(b)*4 : int(b)*4+4]
 }
 
+// isDomainNameLabelSpecial returns true if
+// a domain name label byte should be prefixed
+// with an escaping backslash.
+func isDomainNameLabelSpecial(b byte) bool {
+	switch b {
+	case '.', ' ', '\'', '@', ';', '(', ')', '"', '\\':
+		return true
+	}
+	return false
+}
+
 func nextByte(s string, offset int) (byte, int) {
 	if offset >= len(s) {
 		return 0, 0
@@ -1121,6 +1131,7 @@ type URI struct {
 	Target   string `dns:"octet"`
 }
 
+// rr.Target to be parsed as a sequence of character encoded octets according to RFC 3986
 func (rr *URI) String() string {
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Priority)) +
 		" " + strconv.Itoa(int(rr.Weight)) + " " + sprintTxtOctet(rr.Target)
@@ -1282,6 +1293,7 @@ type CAA struct {
 	Value string `dns:"octet"`
 }
 
+// rr.Value Is the character-string encoding of the value field as specified in RFC 1035, Section 5.1.
 func (rr *CAA) String() string {
 	return rr.Hdr.String() + strconv.Itoa(int(rr.Flag)) + " " + rr.Tag + " " + sprintTxtOctet(rr.Value)
 }
@@ -1358,6 +1370,23 @@ func (rr *CSYNC) len(off int, compression map[string]struct{}) int {
 	return l
 }
 
+// ZONEMD RR, from draft-ietf-dnsop-dns-zone-digest
+type ZONEMD struct {
+	Hdr    RR_Header
+	Serial uint32
+	Scheme uint8
+	Hash   uint8
+	Digest string `dns:"hex"`
+}
+
+func (rr *ZONEMD) String() string {
+	return rr.Hdr.String() +
+		strconv.Itoa(int(rr.Serial)) +
+		" " + strconv.Itoa(int(rr.Scheme)) +
+		" " + strconv.Itoa(int(rr.Hash)) +
+		" " + rr.Digest
+}
+
 // APL RR. See RFC 3123.
 type APL struct {
 	Hdr      RR_Header
@@ -1384,13 +1413,13 @@ func (rr *APL) String() string {
 }
 
 // str returns presentation form of the APL prefix.
-func (p *APLPrefix) str() string {
+func (a *APLPrefix) str() string {
 	var sb strings.Builder
-	if p.Negation {
+	if a.Negation {
 		sb.WriteByte('!')
 	}
 
-	switch len(p.Network.IP) {
+	switch len(a.Network.IP) {
 	case net.IPv4len:
 		sb.WriteByte('1')
 	case net.IPv6len:
@@ -1399,20 +1428,20 @@ func (p *APLPrefix) str() string {
 
 	sb.WriteByte(':')
 
-	switch len(p.Network.IP) {
+	switch len(a.Network.IP) {
 	case net.IPv4len:
-		sb.WriteString(p.Network.IP.String())
+		sb.WriteString(a.Network.IP.String())
 	case net.IPv6len:
 		// add prefix for IPv4-mapped IPv6
-		if v4 := p.Network.IP.To4(); v4 != nil {
+		if v4 := a.Network.IP.To4(); v4 != nil {
 			sb.WriteString("::ffff:")
 		}
-		sb.WriteString(p.Network.IP.String())
+		sb.WriteString(a.Network.IP.String())
 	}
 
 	sb.WriteByte('/')
 
-	prefix, _ := p.Network.Mask.Size()
+	prefix, _ := a.Network.Mask.Size()
 	sb.WriteString(strconv.Itoa(prefix))
 
 	return sb.String()
@@ -1426,17 +1455,17 @@ func (a *APLPrefix) equals(b *APLPrefix) bool {
 }
 
 // copy returns a copy of the APL prefix.
-func (p *APLPrefix) copy() APLPrefix {
+func (a *APLPrefix) copy() APLPrefix {
 	return APLPrefix{
-		Negation: p.Negation,
-		Network:  copyNet(p.Network),
+		Negation: a.Negation,
+		Network:  copyNet(a.Network),
 	}
 }
 
 // len returns size of the prefix in wire format.
-func (p *APLPrefix) len() int {
+func (a *APLPrefix) len() int {
 	// 4-byte header and the network address prefix (see Section 4 of RFC 3123)
-	prefix, _ := p.Network.Mask.Size()
+	prefix, _ := a.Network.Mask.Size()
 	return 4 + (prefix+7)/8
 }
 
@@ -1469,7 +1498,7 @@ func StringToTime(s string) (uint32, error) {
 
 // saltToString converts a NSECX salt to uppercase and returns "-" when it is empty.
 func saltToString(s string) string {
-	if len(s) == 0 {
+	if s == "" {
 		return "-"
 	}
 	return strings.ToUpper(s)

+ 21 - 5
vendor/github.com/miekg/dns/types_generate.go

@@ -72,6 +72,9 @@ func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
 	if !ok {
 		return nil, false
 	}
+	if st.NumFields() == 0 {
+		return nil, false
+	}
 	if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
 		return st, false
 	}
@@ -181,6 +184,8 @@ func main() {
 					o("for _, x := range rr.%s { l += len(x) + 1 }\n")
 				case `dns:"apl"`:
 					o("for _, x := range rr.%s { l += x.len() }\n")
+				case `dns:"pairs"`:
+					o("for _, x := range rr.%s { l += 4 + int(x.len()) }\n")
 				default:
 					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
 				}
@@ -241,11 +246,15 @@ func main() {
 	for _, name := range namedTypes {
 		o := scope.Lookup(name)
 		st, isEmbedded := getTypeStruct(o.Type(), scope)
+		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
+		fields := make([]string, 0, st.NumFields())
 		if isEmbedded {
-			continue
+			a, _ := o.Type().Underlying().(*types.Struct)
+			parent := a.Field(0).Name()
+			fields = append(fields, "*rr."+parent+".copy().(*"+parent+")")
+			goto WriteCopy
 		}
-		fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
-		fields := []string{"rr.Hdr"}
+		fields = append(fields, "rr.Hdr")
 		for i := 1; i < st.NumFields(); i++ {
 			f := st.Field(i).Name()
 			if sl, ok := st.Field(i).Type().(*types.Slice); ok {
@@ -263,8 +272,14 @@ func main() {
 					continue
 				}
 				if t == "APLPrefix" {
-					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i := range rr.%s {\n %s[i] = rr.%s[i].copy()\n}\n",
-						f, t, f, f, f, f)
+					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n",
+						f, t, f, f, f)
+					fields = append(fields, f)
+					continue
+				}
+				if t == "SVCBKeyValue" {
+					fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n",
+						f, t, f, f, f)
 					fields = append(fields, f)
 					continue
 				}
@@ -279,6 +294,7 @@ func main() {
 			}
 			fields = append(fields, "rr."+f)
 		}
+	WriteCopy:
 		fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ","))
 		fmt.Fprintf(b, "}\n")
 	}

+ 3 - 1
vendor/github.com/miekg/dns/update.go

@@ -32,7 +32,9 @@ func (u *Msg) Used(rr []RR) {
 		u.Answer = make([]RR, 0, len(rr))
 	}
 	for _, r := range rr {
-		r.Header().Class = u.Question[0].Qclass
+		hdr := r.Header()
+		hdr.Class = u.Question[0].Qclass
+		hdr.Ttl = 0
 		u.Answer = append(u.Answer, r)
 	}
 }

+ 1 - 1
vendor/github.com/miekg/dns/version.go

@@ -3,7 +3,7 @@ package dns
 import "fmt"
 
 // Version is current version of this library.
-var Version = v{1, 1, 29}
+var Version = v{1, 1, 43}
 
 // v holds the version of this library.
 type v struct {

+ 183 - 0
vendor/github.com/miekg/dns/zduplicate.go

@@ -104,6 +104,48 @@ func (r1 *CAA) isDuplicate(_r2 RR) bool {
 	return true
 }
 
+func (r1 *CDNSKEY) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*CDNSKEY)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Flags != r2.Flags {
+		return false
+	}
+	if r1.Protocol != r2.Protocol {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.PublicKey != r2.PublicKey {
+		return false
+	}
+	return true
+}
+
+func (r1 *CDS) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*CDS)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.DigestType != r2.DigestType {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}
+
 func (r1 *CERT) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*CERT)
 	if !ok {
@@ -172,6 +214,27 @@ func (r1 *DHCID) isDuplicate(_r2 RR) bool {
 	return true
 }
 
+func (r1 *DLV) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*DLV)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.DigestType != r2.DigestType {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}
+
 func (r1 *DNAME) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*DNAME)
 	if !ok {
@@ -339,6 +402,48 @@ func (r1 *HIP) isDuplicate(_r2 RR) bool {
 	return true
 }
 
+func (r1 *HTTPS) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*HTTPS)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Priority != r2.Priority {
+		return false
+	}
+	if !isDuplicateName(r1.Target, r2.Target) {
+		return false
+	}
+	if len(r1.Value) != len(r2.Value) {
+		return false
+	}
+	if !areSVCBPairArraysEqual(r1.Value, r2.Value) {
+		return false
+	}
+	return true
+}
+
+func (r1 *KEY) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*KEY)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Flags != r2.Flags {
+		return false
+	}
+	if r1.Protocol != r2.Protocol {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.PublicKey != r2.PublicKey {
+		return false
+	}
+	return true
+}
+
 func (r1 *KX) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*KX)
 	if !ok {
@@ -849,6 +954,42 @@ func (r1 *RT) isDuplicate(_r2 RR) bool {
 	return true
 }
 
+func (r1 *SIG) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*SIG)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.TypeCovered != r2.TypeCovered {
+		return false
+	}
+	if r1.Algorithm != r2.Algorithm {
+		return false
+	}
+	if r1.Labels != r2.Labels {
+		return false
+	}
+	if r1.OrigTtl != r2.OrigTtl {
+		return false
+	}
+	if r1.Expiration != r2.Expiration {
+		return false
+	}
+	if r1.Inception != r2.Inception {
+		return false
+	}
+	if r1.KeyTag != r2.KeyTag {
+		return false
+	}
+	if !isDuplicateName(r1.SignerName, r2.SignerName) {
+		return false
+	}
+	if r1.Signature != r2.Signature {
+		return false
+	}
+	return true
+}
+
 func (r1 *SMIMEA) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*SMIMEA)
 	if !ok {
@@ -956,6 +1097,27 @@ func (r1 *SSHFP) isDuplicate(_r2 RR) bool {
 	return true
 }
 
+func (r1 *SVCB) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*SVCB)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Priority != r2.Priority {
+		return false
+	}
+	if !isDuplicateName(r1.Target, r2.Target) {
+		return false
+	}
+	if len(r1.Value) != len(r2.Value) {
+		return false
+	}
+	if !areSVCBPairArraysEqual(r1.Value, r2.Value) {
+		return false
+	}
+	return true
+}
+
 func (r1 *TA) isDuplicate(_r2 RR) bool {
 	r2, ok := _r2.(*TA)
 	if !ok {
@@ -1155,3 +1317,24 @@ func (r1 *X25) isDuplicate(_r2 RR) bool {
 	}
 	return true
 }
+
+func (r1 *ZONEMD) isDuplicate(_r2 RR) bool {
+	r2, ok := _r2.(*ZONEMD)
+	if !ok {
+		return false
+	}
+	_ = r2
+	if r1.Serial != r2.Serial {
+		return false
+	}
+	if r1.Scheme != r2.Scheme {
+		return false
+	}
+	if r1.Hash != r2.Hash {
+		return false
+	}
+	if r1.Digest != r2.Digest {
+		return false
+	}
+	return true
+}

+ 134 - 0
vendor/github.com/miekg/dns/zmsg.go

@@ -316,6 +316,22 @@ func (rr *HIP) pack(msg []byte, off int, compression compressionMap, compress bo
 	return off, nil
 }
 
+func (rr *HTTPS) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint16(rr.Priority, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDomainName(rr.Target, msg, off, compression, false)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDataSVCB(rr.Value, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *KEY) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 	off, err = packUint16(rr.Flags, msg, off)
 	if err != nil {
@@ -906,6 +922,22 @@ func (rr *SSHFP) pack(msg []byte, off int, compression compressionMap, compress
 	return off, nil
 }
 
+func (rr *SVCB) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint16(rr.Priority, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDomainName(rr.Target, msg, off, compression, false)
+	if err != nil {
+		return off, err
+	}
+	off, err = packDataSVCB(rr.Value, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *TA) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 	off, err = packUint16(rr.KeyTag, msg, off)
 	if err != nil {
@@ -1086,6 +1118,26 @@ func (rr *X25) pack(msg []byte, off int, compression compressionMap, compress bo
 	return off, nil
 }
 
+func (rr *ZONEMD) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+	off, err = packUint32(rr.Serial, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packUint8(rr.Scheme, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packUint8(rr.Hash, msg, off)
+	if err != nil {
+		return off, err
+	}
+	off, err = packStringHex(rr.Digest, msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 // unpack*() functions
 
 func (rr *A) unpack(msg []byte, off int) (off1 int, err error) {
@@ -1559,6 +1611,31 @@ func (rr *HIP) unpack(msg []byte, off int) (off1 int, err error) {
 	return off, nil
 }
 
+func (rr *HTTPS) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Priority, off, err = unpackUint16(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Target, off, err = UnpackDomainName(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Value, off, err = unpackDataSVCB(msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *KEY) unpack(msg []byte, off int) (off1 int, err error) {
 	rdStart := off
 	_ = rdStart
@@ -2461,6 +2538,31 @@ func (rr *SSHFP) unpack(msg []byte, off int) (off1 int, err error) {
 	return off, nil
 }
 
+func (rr *SVCB) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Priority, off, err = unpackUint16(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Target, off, err = UnpackDomainName(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Value, off, err = unpackDataSVCB(msg, off)
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}
+
 func (rr *TA) unpack(msg []byte, off int) (off1 int, err error) {
 	rdStart := off
 	_ = rdStart
@@ -2739,3 +2841,35 @@ func (rr *X25) unpack(msg []byte, off int) (off1 int, err error) {
 	}
 	return off, nil
 }
+
+func (rr *ZONEMD) unpack(msg []byte, off int) (off1 int, err error) {
+	rdStart := off
+	_ = rdStart
+
+	rr.Serial, off, err = unpackUint32(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Scheme, off, err = unpackUint8(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Hash, off, err = unpackUint8(msg, off)
+	if err != nil {
+		return off, err
+	}
+	if off == len(msg) {
+		return off, nil
+	}
+	rr.Digest, off, err = unpackStringHex(msg, off, rdStart+int(rr.Hdr.Rdlength))
+	if err != nil {
+		return off, err
+	}
+	return off, nil
+}

+ 56 - 2
vendor/github.com/miekg/dns/ztypes.go

@@ -33,6 +33,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeGPOS:       func() RR { return new(GPOS) },
 	TypeHINFO:      func() RR { return new(HINFO) },
 	TypeHIP:        func() RR { return new(HIP) },
+	TypeHTTPS:      func() RR { return new(HTTPS) },
 	TypeKEY:        func() RR { return new(KEY) },
 	TypeKX:         func() RR { return new(KX) },
 	TypeL32:        func() RR { return new(L32) },
@@ -70,6 +71,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeSPF:        func() RR { return new(SPF) },
 	TypeSRV:        func() RR { return new(SRV) },
 	TypeSSHFP:      func() RR { return new(SSHFP) },
+	TypeSVCB:       func() RR { return new(SVCB) },
 	TypeTA:         func() RR { return new(TA) },
 	TypeTALINK:     func() RR { return new(TALINK) },
 	TypeTKEY:       func() RR { return new(TKEY) },
@@ -80,6 +82,7 @@ var TypeToRR = map[uint16]func() RR{
 	TypeUINFO:      func() RR { return new(UINFO) },
 	TypeURI:        func() RR { return new(URI) },
 	TypeX25:        func() RR { return new(X25) },
+	TypeZONEMD:     func() RR { return new(ZONEMD) },
 }
 
 // TypeToString is a map of strings for each RR type.
@@ -110,6 +113,7 @@ var TypeToString = map[uint16]string{
 	TypeGPOS:       "GPOS",
 	TypeHINFO:      "HINFO",
 	TypeHIP:        "HIP",
+	TypeHTTPS:      "HTTPS",
 	TypeISDN:       "ISDN",
 	TypeIXFR:       "IXFR",
 	TypeKEY:        "KEY",
@@ -153,6 +157,7 @@ var TypeToString = map[uint16]string{
 	TypeSPF:        "SPF",
 	TypeSRV:        "SRV",
 	TypeSSHFP:      "SSHFP",
+	TypeSVCB:       "SVCB",
 	TypeTA:         "TA",
 	TypeTALINK:     "TALINK",
 	TypeTKEY:       "TKEY",
@@ -164,6 +169,7 @@ var TypeToString = map[uint16]string{
 	TypeUNSPEC:     "UNSPEC",
 	TypeURI:        "URI",
 	TypeX25:        "X25",
+	TypeZONEMD:     "ZONEMD",
 	TypeNSAPPTR:    "NSAP-PTR",
 }
 
@@ -191,6 +197,7 @@ func (rr *GID) Header() *RR_Header        { return &rr.Hdr }
 func (rr *GPOS) Header() *RR_Header       { return &rr.Hdr }
 func (rr *HINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *HIP) Header() *RR_Header        { return &rr.Hdr }
+func (rr *HTTPS) Header() *RR_Header      { return &rr.Hdr }
 func (rr *KEY) Header() *RR_Header        { return &rr.Hdr }
 func (rr *KX) Header() *RR_Header         { return &rr.Hdr }
 func (rr *L32) Header() *RR_Header        { return &rr.Hdr }
@@ -229,6 +236,7 @@ func (rr *SOA) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SPF) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SRV) Header() *RR_Header        { return &rr.Hdr }
 func (rr *SSHFP) Header() *RR_Header      { return &rr.Hdr }
+func (rr *SVCB) Header() *RR_Header       { return &rr.Hdr }
 func (rr *TA) Header() *RR_Header         { return &rr.Hdr }
 func (rr *TALINK) Header() *RR_Header     { return &rr.Hdr }
 func (rr *TKEY) Header() *RR_Header       { return &rr.Hdr }
@@ -239,6 +247,7 @@ func (rr *UID) Header() *RR_Header        { return &rr.Hdr }
 func (rr *UINFO) Header() *RR_Header      { return &rr.Hdr }
 func (rr *URI) Header() *RR_Header        { return &rr.Hdr }
 func (rr *X25) Header() *RR_Header        { return &rr.Hdr }
+func (rr *ZONEMD) Header() *RR_Header     { return &rr.Hdr }
 
 // len() functions
 func (rr *A) len(off int, compression map[string]struct{}) int {
@@ -592,6 +601,15 @@ func (rr *SSHFP) len(off int, compression map[string]struct{}) int {
 	l += len(rr.FingerPrint) / 2
 	return l
 }
+func (rr *SVCB) len(off int, compression map[string]struct{}) int {
+	l := rr.Hdr.len(off, compression)
+	l += 2 // Priority
+	l += domainNameLen(rr.Target, off+l, compression, false)
+	for _, x := range rr.Value {
+		l += 4 + int(x.len())
+	}
+	return l
+}
 func (rr *TA) len(off int, compression map[string]struct{}) int {
 	l := rr.Hdr.len(off, compression)
 	l += 2 // KeyTag
@@ -669,6 +687,14 @@ func (rr *X25) len(off int, compression map[string]struct{}) int {
 	l += len(rr.PSDNAddress) + 1
 	return l
 }
+func (rr *ZONEMD) len(off int, compression map[string]struct{}) int {
+	l := rr.Hdr.len(off, compression)
+	l += 4 // Serial
+	l++    // Scheme
+	l++    // Hash
+	l += len(rr.Digest) / 2
+	return l
+}
 
 // copy() functions
 func (rr *A) copy() RR {
@@ -685,8 +711,8 @@ func (rr *ANY) copy() RR {
 }
 func (rr *APL) copy() RR {
 	Prefixes := make([]APLPrefix, len(rr.Prefixes))
-	for i := range rr.Prefixes {
-		Prefixes[i] = rr.Prefixes[i].copy()
+	for i, e := range rr.Prefixes {
+		Prefixes[i] = e.copy()
 	}
 	return &APL{rr.Hdr, Prefixes}
 }
@@ -698,6 +724,12 @@ func (rr *AVC) copy() RR {
 func (rr *CAA) copy() RR {
 	return &CAA{rr.Hdr, rr.Flag, rr.Tag, rr.Value}
 }
+func (rr *CDNSKEY) copy() RR {
+	return &CDNSKEY{*rr.DNSKEY.copy().(*DNSKEY)}
+}
+func (rr *CDS) copy() RR {
+	return &CDS{*rr.DS.copy().(*DS)}
+}
 func (rr *CERT) copy() RR {
 	return &CERT{rr.Hdr, rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
 }
@@ -712,6 +744,9 @@ func (rr *CSYNC) copy() RR {
 func (rr *DHCID) copy() RR {
 	return &DHCID{rr.Hdr, rr.Digest}
 }
+func (rr *DLV) copy() RR {
+	return &DLV{*rr.DS.copy().(*DS)}
+}
 func (rr *DNAME) copy() RR {
 	return &DNAME{rr.Hdr, rr.Target}
 }
@@ -744,6 +779,12 @@ func (rr *HIP) copy() RR {
 	copy(RendezvousServers, rr.RendezvousServers)
 	return &HIP{rr.Hdr, rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
 }
+func (rr *HTTPS) copy() RR {
+	return &HTTPS{*rr.SVCB.copy().(*SVCB)}
+}
+func (rr *KEY) copy() RR {
+	return &KEY{*rr.DNSKEY.copy().(*DNSKEY)}
+}
 func (rr *KX) copy() RR {
 	return &KX{rr.Hdr, rr.Preference, rr.Exchanger}
 }
@@ -847,6 +888,9 @@ func (rr *RRSIG) copy() RR {
 func (rr *RT) copy() RR {
 	return &RT{rr.Hdr, rr.Preference, rr.Host}
 }
+func (rr *SIG) copy() RR {
+	return &SIG{*rr.RRSIG.copy().(*RRSIG)}
+}
 func (rr *SMIMEA) copy() RR {
 	return &SMIMEA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
 }
@@ -864,6 +908,13 @@ func (rr *SRV) copy() RR {
 func (rr *SSHFP) copy() RR {
 	return &SSHFP{rr.Hdr, rr.Algorithm, rr.Type, rr.FingerPrint}
 }
+func (rr *SVCB) copy() RR {
+	Value := make([]SVCBKeyValue, len(rr.Value))
+	for i, e := range rr.Value {
+		Value[i] = e.copy()
+	}
+	return &SVCB{rr.Hdr, rr.Priority, rr.Target, Value}
+}
 func (rr *TA) copy() RR {
 	return &TA{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
 }
@@ -896,3 +947,6 @@ func (rr *URI) copy() RR {
 func (rr *X25) copy() RR {
 	return &X25{rr.Hdr, rr.PSDNAddress}
 }
+func (rr *ZONEMD) copy() RR {
+	return &ZONEMD{rr.Hdr, rr.Serial, rr.Scheme, rr.Hash, rr.Digest}
+}

+ 20 - 0
vendor/github.com/refraction-networking/utls/README.md

@@ -102,6 +102,26 @@ you can set UConn.HandshakeStateBuilt = true, and marshal clientHello into UConn
 In this case you will be responsible for modifying other parts of Config and ClientHelloMsg to reflect your setup
 and not confuse "crypto/tls", which will be processing response from server.
 
+### Fingerprinting Captured Client Hello
+You can use a captured client hello to generate new ones that mimic/have the same properties as the original.
+The generated client hellos _should_ look like they were generated from the same client software as the original fingerprinted bytes.
+In order to do this:
+1) Create a `ClientHelloSpec` from the raw bytes of the original client hello
+2) Use `HelloCustom` as an argument for `UClient()` to get empty config
+3) Use `ApplyPreset` with the generated `ClientHelloSpec` to set the appropriate connection properties
+```
+uConn := UClient(&net.TCPConn{}, nil, HelloCustom)
+fingerprinter := &Fingerprinter{}
+generatedSpec, err := fingerprinter.FingerprintClientHello(rawCapturedClientHelloBytes)
+if err != nil {
+  panic("fingerprinting failed: %v", err)
+}
+if err := uConn.ApplyPreset(generatedSpec); err != nil {
+  panic("applying generated spec failed: %v", err)
+}
+```
+The `rawCapturedClientHelloBytes` should be the full tls record, including the record type/version/length header.
+
 ## Roller
 
 A simple wrapper, that allows to easily use multiple latest(auto-updated) fingerprints.

+ 5 - 1
vendor/github.com/refraction-networking/utls/key_agreement.go

@@ -69,7 +69,11 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
 		return nil, nil, err
 	}
 
-	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
+	rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
+	if !ok {
+		return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
+	}
+	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
 	if err != nil {
 		return nil, nil, err
 	}

+ 16 - 1
vendor/github.com/refraction-networking/utls/ticket.go

@@ -171,7 +171,20 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
 	return encrypted, nil
 }
 
+// [uTLS] changed to use exported DecryptTicketWith func below
 func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
+	tks := ticketKeys(c.config.ticketKeys()).ToPublic()
+	return DecryptTicketWith(encrypted, tks)
+}
+
+// DecryptTicketWith decrypts an encrypted session ticket
+// using a TicketKeys (ie []TicketKey) struct
+//
+// usedOldKey will be true if the key used for decryption is
+// not the first in the []TicketKey slice
+//
+// [uTLS] changed to be made public and take a TicketKeys instead of use a Conn receiver
+func DecryptTicketWith(encrypted []byte, tks TicketKeys) (plaintext []byte, usedOldKey bool) {
 	if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
 		return nil, false
 	}
@@ -181,7 +194,9 @@ func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey boo
 	macBytes := encrypted[len(encrypted)-sha256.Size:]
 	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
 
-	keys := c.config.ticketKeys()
+	// keys := c.config.ticketKeys() // [uTLS] keys are received as a function argument
+
+	keys := tks.ToPrivate()
 	keyIndex := -1
 	for i, candidateKey := range keys {
 		if bytes.Equal(keyName, candidateKey.keyName[:]) {

+ 15 - 0
vendor/github.com/refraction-networking/utls/u_common.go

@@ -39,6 +39,7 @@ const (
 
 	FAKE_TLS_DHE_RSA_WITH_AES_128_CBC_SHA  = uint16(0x0033)
 	FAKE_TLS_DHE_RSA_WITH_AES_256_CBC_SHA  = uint16(0x0039)
+	FAKE_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384  = uint16(0x009f)
 	FAKE_TLS_RSA_WITH_RC4_128_MD5          = uint16(0x0004)
 	FAKE_TLS_EMPTY_RENEGOTIATION_INFO_SCSV = uint16(0x00ff)
 )
@@ -161,6 +162,20 @@ var (
 // https://tools.ietf.org/html/draft-ietf-tls-grease-01
 const GREASE_PLACEHOLDER = 0x0a0a
 
+func isGREASEUint16(v uint16) bool {
+	// First byte is same as second byte
+	// and lowest nibble is 0xa
+	return ((v >> 8) == v&0xff) && v&0xf == 0xa
+}
+
+func unGREASEUint16(v uint16) uint16 {
+	if isGREASEUint16(v) {
+		return GREASE_PLACEHOLDER
+	} else {
+		return v
+	}
+}
+
 // utlsMacSHA384 returns a SHA-384 based MAC. These are only supported in TLS 1.2
 // so the given version is ignored.
 func utlsMacSHA384(version uint16, key []byte) macFunction {

+ 360 - 0
vendor/github.com/refraction-networking/utls/u_fingerprinter.go

@@ -0,0 +1,360 @@
+// Copyright 2017 Google Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"golang.org/x/crypto/cryptobyte"
+)
+
+// Fingerprinter is a struct largely for holding options for the FingerprintClientHello func
+type Fingerprinter struct {
+	// KeepPSK will ensure that the PreSharedKey extension is passed along into the resulting ClientHelloSpec as-is
+	KeepPSK bool
+	// AllowBluntMimicry will ensure that unknown extensions are
+	// passed along into the resulting ClientHelloSpec as-is
+	// It will not ensure that the PSK is passed along, if you require that, use KeepPSK
+	// WARNING: there could be numerous subtle issues with ClientHelloSpecs
+	// that are generated with this flag which could compromise security and/or mimicry
+	AllowBluntMimicry bool
+	// AlwaysAddPadding will always add a UtlsPaddingExtension with BoringPaddingStyle
+	// at the end of the extensions list if it isn't found in the fingerprinted hello.
+	// This could be useful in scenarios where the hello you are fingerprinting does not
+	// have any padding, but you suspect that other changes you make to the final hello
+	// (including things like different SNI lengths) would cause padding to be necessary
+	AlwaysAddPadding bool
+}
+
+// FingerprintClientHello returns a ClientHelloSpec which is based on the
+// ClientHello that is passed in as the data argument
+//
+// If the ClientHello passed in has extensions that are not recognized or cannot be handled
+// it will return a non-nil error and a nil *ClientHelloSpec value
+//
+// The data should be the full tls record, including the record type/version/length header
+// as well as the handshake type/length/version header
+// https://tools.ietf.org/html/rfc5246#section-6.2
+// https://tools.ietf.org/html/rfc5246#section-7.4
+func (f *Fingerprinter) FingerprintClientHello(data []byte) (*ClientHelloSpec, error) {
+	clientHelloSpec := &ClientHelloSpec{}
+	s := cryptobyte.String(data)
+
+	var contentType uint8
+	var recordVersion uint16
+	if !s.ReadUint8(&contentType) || // record type
+		!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
+		return nil, errors.New("unable to read record type, version, and length")
+	}
+
+	if recordType(contentType) != recordTypeHandshake {
+		return nil, errors.New("record is not a handshake")
+	}
+
+	var handshakeVersion uint16
+	var handshakeType uint8
+
+	if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
+		!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
+		return nil, errors.New("unable to read handshake message type, length, and random")
+	}
+
+	if handshakeType != typeClientHello {
+		return nil, errors.New("handshake message is not a ClientHello")
+	}
+
+	clientHelloSpec.TLSVersMin = recordVersion
+	clientHelloSpec.TLSVersMax = handshakeVersion
+
+	var ignoredSessionID cryptobyte.String
+	if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
+		return nil, errors.New("unable to read session id")
+	}
+
+	var cipherSuitesBytes cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
+		return nil, errors.New("unable to read ciphersuites")
+	}
+	cipherSuites := []uint16{}
+	for !cipherSuitesBytes.Empty() {
+		var suite uint16
+		if !cipherSuitesBytes.ReadUint16(&suite) {
+			return nil, errors.New("unable to read ciphersuite")
+		}
+		cipherSuites = append(cipherSuites, unGREASEUint16(suite))
+	}
+	clientHelloSpec.CipherSuites = cipherSuites
+
+	if !readUint8LengthPrefixed(&s, &clientHelloSpec.CompressionMethods) {
+		return nil, errors.New("unable to read compression methods")
+	}
+
+	if s.Empty() {
+		// ClientHello is optionally followed by extension data
+		return clientHelloSpec, nil
+	}
+
+	var extensions cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
+		return nil, errors.New("unable to read extensions data")
+	}
+
+	for !extensions.Empty() {
+		var extension uint16
+		var extData cryptobyte.String
+		if !extensions.ReadUint16(&extension) ||
+			!extensions.ReadUint16LengthPrefixed(&extData) {
+			return nil, errors.New("unable to read extension data")
+		}
+
+		switch extension {
+		case extensionServerName:
+			// RFC 6066, Section 3
+			var nameList cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
+				return nil, errors.New("unable to read server name extension data")
+			}
+			var serverName string
+			for !nameList.Empty() {
+				var nameType uint8
+				var serverNameBytes cryptobyte.String
+				if !nameList.ReadUint8(&nameType) ||
+					!nameList.ReadUint16LengthPrefixed(&serverNameBytes) ||
+					serverNameBytes.Empty() {
+					return nil, errors.New("unable to read server name extension data")
+				}
+				if nameType != 0 {
+					continue
+				}
+				if len(serverName) != 0 {
+					return nil, errors.New("multiple names of the same name_type in server name extension are prohibited")
+				}
+				serverName = string(serverNameBytes)
+				if strings.HasSuffix(serverName, ".") {
+					return nil, errors.New("SNI value may not include a trailing dot")
+				}
+
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SNIExtension{})
+
+			}
+		case extensionNextProtoNeg:
+			// draft-agl-tls-nextprotoneg-04
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &NPNExtension{})
+
+		case extensionStatusRequest:
+			// RFC 4366, Section 3.6
+			var statusType uint8
+			var ignored cryptobyte.String
+			if !extData.ReadUint8(&statusType) ||
+				!extData.ReadUint16LengthPrefixed(&ignored) ||
+				!extData.ReadUint16LengthPrefixed(&ignored) {
+				return nil, errors.New("unable to read status request extension data")
+			}
+
+			if statusType == statusTypeOCSP {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &StatusRequestExtension{})
+			} else {
+				return nil, errors.New("status request extension statusType is not statusTypeOCSP")
+			}
+
+		case extensionSupportedCurves:
+			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
+			var curvesBytes cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&curvesBytes) || curvesBytes.Empty() {
+				return nil, errors.New("unable to read supported curves extension data")
+			}
+			curves := []CurveID{}
+			for !curvesBytes.Empty() {
+				var curve uint16
+				if !curvesBytes.ReadUint16(&curve) {
+					return nil, errors.New("unable to read supported curves extension data")
+				}
+				curves = append(curves, CurveID(unGREASEUint16(curve)))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedCurvesExtension{curves})
+
+		case extensionSupportedPoints:
+			// RFC 4492, Section 5.1.2
+			supportedPoints := []uint8{}
+			if !readUint8LengthPrefixed(&extData, &supportedPoints) ||
+				len(supportedPoints) == 0 {
+				return nil, errors.New("unable to read supported points extension data")
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedPointsExtension{supportedPoints})
+
+		case extensionSessionTicket:
+			// RFC 5077, Section 3.2
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SessionTicketExtension{})
+
+		case extensionSignatureAlgorithms:
+			// RFC 5246, Section 7.4.1.4.1
+			var sigAndAlgs cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+				return nil, errors.New("unable to read signature algorithms extension data")
+			}
+			supportedSignatureAlgorithms := []SignatureScheme{}
+			for !sigAndAlgs.Empty() {
+				var sigAndAlg uint16
+				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+					return nil, errors.New("unable to read signature algorithms extension data")
+				}
+				supportedSignatureAlgorithms = append(
+					supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SignatureAlgorithmsExtension{supportedSignatureAlgorithms})
+
+		case extensionSignatureAlgorithmsCert:
+			// RFC 8446, Section 4.2.3
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension SignatureAlgorithmsCert")
+			}
+
+		case extensionRenegotiationInfo:
+			// RFC 5746, Section 3.2
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &RenegotiationInfoExtension{RenegotiateOnceAsClient})
+
+		case extensionALPN:
+			// RFC 7301, Section 3.1
+			var protoList cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
+				return nil, errors.New("unable to read ALPN extension data")
+			}
+			alpnProtocols := []string{}
+			for !protoList.Empty() {
+				var proto cryptobyte.String
+				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
+					return nil, errors.New("unable to read ALPN extension data")
+				}
+				alpnProtocols = append(alpnProtocols, string(proto))
+
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &ALPNExtension{alpnProtocols})
+
+		case extensionSCT:
+			// RFC 6962, Section 3.3.1
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SCTExtension{})
+
+		case extensionSupportedVersions:
+			// RFC 8446, Section 4.2.1
+			var versList cryptobyte.String
+			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
+				return nil, errors.New("unable to read supported versions extension data")
+			}
+			supportedVersions := []uint16{}
+			for !versList.Empty() {
+				var vers uint16
+				if !versList.ReadUint16(&vers) {
+					return nil, errors.New("unable to read supported versions extension data")
+				}
+				supportedVersions = append(supportedVersions, unGREASEUint16(vers))
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedVersionsExtension{supportedVersions})
+			// If SupportedVersionsExtension is present, use that instead of record+handshake versions
+			clientHelloSpec.TLSVersMin = 0
+			clientHelloSpec.TLSVersMax = 0
+
+		case extensionKeyShare:
+			// RFC 8446, Section 4.2.8
+			var clientShares cryptobyte.String
+			if !extData.ReadUint16LengthPrefixed(&clientShares) {
+				return nil, errors.New("unable to read key share extension data")
+			}
+			keyShares := []KeyShare{}
+			for !clientShares.Empty() {
+				var ks KeyShare
+				var group uint16
+				if !clientShares.ReadUint16(&group) ||
+					!readUint16LengthPrefixed(&clientShares, &ks.Data) ||
+					len(ks.Data) == 0 {
+					return nil, errors.New("unable to read key share extension data")
+				}
+				ks.Group = CurveID(unGREASEUint16(group))
+				// if not GREASE, key share data will be discarded as it should
+				// be generated per connection
+				if ks.Group != GREASE_PLACEHOLDER {
+					ks.Data = nil
+				}
+				keyShares = append(keyShares, ks)
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &KeyShareExtension{keyShares})
+
+		case extensionPSKModes:
+			// RFC 8446, Section 4.2.9
+			// TODO: PSK Modes have their own form of GREASE-ing which is not currently implemented
+			// the current functionality will NOT re-GREASE/re-randomize these values when using a fingerprinted spec
+			// https://github.com/refraction-networking/utls/pull/58#discussion_r522354105
+			// https://tools.ietf.org/html/draft-ietf-tls-grease-01#section-2
+			pskModes := []uint8{}
+			if !readUint8LengthPrefixed(&extData, &pskModes) {
+				return nil, errors.New("unable to read PSK extension data")
+			}
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &PSKKeyExchangeModesExtension{pskModes})
+
+		case utlsExtensionExtendedMasterSecret:
+			// https://tools.ietf.org/html/rfc7627
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsExtendedMasterSecretExtension{})
+
+		case utlsExtensionPadding:
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
+
+		case fakeExtensionChannelID, fakeCertCompressionAlgs, fakeRecordSizeLimit:
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+
+		case extensionPreSharedKey:
+			// RFC 8446, Section 4.2.11
+			if f.KeepPSK {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension PreSharedKey")
+			}
+
+		case extensionCookie:
+			// RFC 8446, Section 4.2.2
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension Cookie")
+			}
+
+		case extensionEarlyData:
+			// RFC 8446, Section 4.2.10
+			if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, errors.New("unsupported extension EarlyData")
+			}
+
+		default:
+			if isGREASEUint16(extension) {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsGREASEExtension{unGREASEUint16(extension), extData})
+			} else if f.AllowBluntMimicry {
+				clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
+			} else {
+				return nil, fmt.Errorf("unsupported extension %#x", extension)
+			}
+
+			continue
+		}
+	}
+
+	if f.AlwaysAddPadding {
+		alreadyHasPadding := false
+		for _, ext := range clientHelloSpec.Extensions {
+			if _, ok := ext.(*UtlsPaddingExtension); ok {
+				alreadyHasPadding = true
+				break
+			}
+		}
+		if !alreadyHasPadding {
+			clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
+		}
+	}
+
+	return clientHelloSpec, nil
+}

+ 10 - 0
vendor/github.com/refraction-networking/utls/u_parrots.go

@@ -617,6 +617,9 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
 	uconn.Extensions = make([]TLSExtension, len(p.Extensions))
 	copy(uconn.Extensions, p.Extensions)
 
+	// Check whether NPN extension actually exists
+	var haveNPN bool
+
 	// reGrease, and point things to each other
 	for _, e := range uconn.Extensions {
 		switch ext := e.(type) {
@@ -681,8 +684,15 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
 					ext.Versions[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_version)
 				}
 			}
+		case *NPNExtension:
+			haveNPN = true
 		}
 	}
+
+	// The default golang behavior in makeClientHello always sets NextProtoNeg if NextProtos is set,
+	// but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension
+	hello.NextProtoNeg = haveNPN
+
 	return nil
 }
 

+ 59 - 0
vendor/github.com/refraction-networking/utls/u_public.go

@@ -421,6 +421,16 @@ func (chm *clientHelloMsg) getPublicPtr() *ClientHelloMsg {
 	}
 }
 
+// UnmarshalClientHello allows external code to parse raw client hellos.
+// It returns nil on failure.
+func UnmarshalClientHello(data []byte) *ClientHelloMsg {
+	m := &clientHelloMsg{}
+	if m.unmarshal(data) {
+		return m.getPublicPtr()
+	}
+	return nil
+}
+
 // A CipherSuite is a specific combination of key agreement, cipher and MAC
 // function. All cipher suites currently assume RSA key agreement.
 type CipherSuite struct {
@@ -612,3 +622,52 @@ func (css *ClientSessionState) SetServerCertificates(ServerCertificates []*x509.
 func (css *ClientSessionState) SetVerifiedChains(VerifiedChains [][]*x509.Certificate) {
 	css.verifiedChains = VerifiedChains
 }
+
+// TicketKey is the internal representation of a session ticket key.
+type TicketKey struct {
+	// KeyName is an opaque byte string that serves to identify the session
+	// ticket key. It's exposed as plaintext in every session ticket.
+	KeyName [ticketKeyNameLen]byte
+	AesKey  [16]byte
+	HmacKey [16]byte
+}
+
+type TicketKeys []TicketKey
+type ticketKeys []ticketKey
+
+func TicketKeyFromBytes(b [32]byte) TicketKey {
+	tk := ticketKeyFromBytes(b)
+	return tk.ToPublic()
+}
+
+func (tk ticketKey) ToPublic() TicketKey {
+	return TicketKey{
+		KeyName: tk.keyName,
+		AesKey:  tk.aesKey,
+		HmacKey: tk.hmacKey,
+	}
+}
+
+func (TK TicketKey) ToPrivate() ticketKey {
+	return ticketKey{
+		keyName: TK.KeyName,
+		aesKey:  TK.AesKey,
+		hmacKey: TK.HmacKey,
+	}
+}
+
+func (tks ticketKeys) ToPublic() []TicketKey {
+	var TKS []TicketKey
+	for _, ks := range tks {
+		TKS = append(TKS, ks.ToPublic())
+	}
+	return TKS
+}
+
+func (TKS TicketKeys) ToPrivate() []ticketKey {
+	var tks []ticketKey
+	for _, TK := range TKS {
+		tks = append(tks, TK.ToPrivate())
+	}
+	return tks
+}

+ 6 - 6
vendor/vendor.json

@@ -457,10 +457,10 @@
 			"revisionTime": "2020-07-28T10:15:04Z"
 		},
 		{
-			"checksumSHA1": "CqKUytzURTAPk1mUOsySyXGQft0=",
+			"checksumSHA1": "zpjq6xHytHc3aAYfkaomcDmeAek=",
 			"path": "github.com/miekg/dns",
-			"revision": "d128d10d176b810543f8fecd089082a29d3f159d",
-			"revisionTime": "2020-04-28T07:24:18Z"
+			"revision": "ab67aa64230094bdd0167ee5360e00e0a250a3ac",
+			"revisionTime": "2021-08-04T16:16:52Z"
 		},
 		{
 			"checksumSHA1": "m2L8ohfZiFRsMW3iynaH/TWgnSY=",
@@ -608,10 +608,10 @@
 			"revisionTime": "2021-06-04T20:39:09Z"
 		},
 		{
-			"checksumSHA1": "OagdWaWcbCBQZR5bBGgGaK3nddE=",
+			"checksumSHA1": "jOxlnqvKSKn1SIkA5ldRe5lxqAc=",
 			"path": "github.com/refraction-networking/utls",
-			"revision": "186025ac7b77465439618d1aeb2a5e444714d1cc",
-			"revisionTime": "2020-07-29T01:25:36Z"
+			"revision": "0b2885c8c0d4467cfe98136748a9d011d0b8fff0",
+			"revisionTime": "2021-07-13T16:56:36Z"
 		},
 		{
 			"checksumSHA1": "Fn9JW8u40ABN9Uc9wuvquuyOB+8=",