dsl.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884
  1. /*
  2. * Copyright (c) 2025, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package testutils
  20. import (
  21. "bytes"
  22. "crypto/rand"
  23. "crypto/rsa"
  24. "crypto/tls"
  25. "crypto/x509"
  26. "crypto/x509/pkix"
  27. "encoding/base64"
  28. "encoding/json"
  29. "encoding/pem"
  30. "fmt"
  31. "io"
  32. "math/big"
  33. "net"
  34. "net/http"
  35. "os"
  36. "path/filepath"
  37. "strings"
  38. "sync"
  39. "sync/atomic"
  40. "time"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
  44. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  45. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  46. )
  // DSLBackendTestShim supplies the DSL protocol details that TestDSLBackend
  // relies on but does not define itself: the HTTP header names checked on
  // each request, the request paths served, and, for each endpoint, a
  // function that unmarshals the request body and one that marshals the
  // response body.
  //
  // NOTE(review): presumably implemented outside this package to avoid an
  // import cycle with the package that owns the DSL wire format — confirm.
  type DSLBackendTestShim interface {
  	// Names of the request headers validated by TestDSLBackend.Start.
  	ClientIPHeaderName() string
  	ClientGeoIPDataHeaderName() string
  	ClientTunneledHeaderName() string
  	HostIDHeaderName() string

  	// Request paths registered on the ServeMux by TestDSLBackend.Start.
  	DiscoverServerEntriesRequestPath() string
  	GetServerEntriesRequestPath() string
  	GetActiveOSLsRequestPath() string
  	GetOSLFileSpecsRequestPath() string

  	// UnmarshalDiscoverServerEntriesRequest decodes a discover server
  	// entries request body. TestDSLBackend uses the returned OSL keys and
  	// discover count and ignores apiParams.
  	UnmarshalDiscoverServerEntriesRequest(
  		cborRequest []byte) (
  		apiParams protocol.PackedAPIParameters,
  		oslKeys [][]byte,
  		discoverCount int32,
  		retErr error)

  	// MarshalDiscoverServerEntriesResponse encodes the discovered server
  	// entry tags, each with a version, as a response body.
  	MarshalDiscoverServerEntriesResponse(
  		versionedServerEntryTags []*struct {
  			Tag     []byte
  			Version int32
  		}) (
  		cborResponse []byte,
  		retErr error)

  	// UnmarshalGetServerEntriesRequest decodes a get server entries request
  	// body into the raw server entry tags being requested.
  	UnmarshalGetServerEntriesRequest(
  		cborRequest []byte) (
  		apiParams protocol.PackedAPIParameters,
  		serverEntryTags [][]byte,
  		retErr error)

  	// MarshalGetServerEntriesResponse encodes the requested packed server
  	// entries, each paired with its source label, as a response body.
  	MarshalGetServerEntriesResponse(
  		sourcedServerEntries []*struct {
  			ServerEntryFields protocol.PackedServerEntryFields
  			Source            string
  		}) (
  		cborResponse []byte,
  		retErr error)

  	// UnmarshalGetActiveOSLsRequest decodes a get active OSLs request body.
  	UnmarshalGetActiveOSLsRequest(
  		cborRequest []byte) (
  		apiParams protocol.PackedAPIParameters,
  		retErr error)

  	// MarshalGetActiveOSLsResponse encodes the active OSL IDs as a response
  	// body.
  	MarshalGetActiveOSLsResponse(
  		activeOSLIDs [][]byte) (
  		cborResponse []byte,
  		retErr error)

  	// UnmarshalGetOSLFileSpecsRequest decodes a get OSL file specs request
  	// body into the OSL IDs being requested.
  	UnmarshalGetOSLFileSpecsRequest(
  		cborRequest []byte) (
  		apiParams protocol.PackedAPIParameters,
  		oslIDs [][]byte,
  		retErr error)

  	// MarshalGetOSLFileSpecsResponse encodes the (already CBOR-encoded, as
  	// produced by TestDSLBackend) OSL file specs as a response body.
  	MarshalGetOSLFileSpecsResponse(
  		oslFileSpecs [][]byte) (
  		cborResponse []byte,
  		retErr error)
  }
  // TestDSLBackend is a mock DSL backend intended only for testing.
  //
  // It serves the endpoints described by its DSLBackendTestShim over mutual
  // TLS (see Start), validates request headers against the expected values
  // it was constructed with, and answers from in-memory server entries and
  // OSL pave data.
  type TestDSLBackend struct {
  	// shim supplies header names, request paths, and request/response
  	// (un)marshaling.
  	shim DSLBackendTestShim
  	// tlsConfig supplies the backend server certificate and the CA used to
  	// verify client certificates.
  	tlsConfig *TestDSLTLSConfig
  	// Expected request header values; an empty string (or nil GeoIP data)
  	// disables the corresponding comparison.
  	expectedClientIP        string
  	expectedClientGeoIPData *common.GeoIPData
  	expectedHostID          string
  	// oslPaveData holds a []*osl.PaveData. It is set via Store in
  	// NewTestDSLBackend and SetOSLPaveData and read via Load in handlers.
  	oslPaveData atomic.Value
  	// Server entries keyed by tag string, served to untunneled and tunneled
  	// requests respectively.
  	// NOTE(review): unlike oslPaveData, these maps are read by request
  	// handlers and replaced by SetServerEntries without synchronization;
  	// calling SetServerEntries while requests are in flight is a data race.
  	untunneledServerEntries map[string]*dslSourcedServerEntry
  	tunneledServerEntries   map[string]*dslSourcedServerEntry
  	// listener is set by Start and closed by Stop.
  	listener net.Listener
  }
  // dslSourcedServerEntry pairs a packed server entry with the source label
  // that the get server entries handler returns alongside it.
  type dslSourcedServerEntry struct {
  	ServerEntryFields protocol.PackedServerEntryFields
  	Source            string
  }
  115. func NewTestDSLBackend(
  116. shim DSLBackendTestShim,
  117. tlsConfig *TestDSLTLSConfig,
  118. expectedClientIP string,
  119. expectedClientGeoIPData *common.GeoIPData,
  120. expectedHostID string,
  121. oslPaveData []*osl.PaveData) (*TestDSLBackend, error) {
  122. b := &TestDSLBackend{
  123. shim: shim,
  124. tlsConfig: tlsConfig,
  125. expectedClientIP: expectedClientIP,
  126. expectedClientGeoIPData: expectedClientGeoIPData,
  127. expectedHostID: expectedHostID,
  128. }
  129. b.oslPaveData.Store(oslPaveData)
  130. // Generate mock server entries.
  131. // Run GenerateConfig concurrently to try to take advantage of multiple
  132. // CPU cores.
  133. //
  134. // Update: no longer using server.GenerateConfig due to import cycle.
  135. var initMutex sync.Mutex
  136. var initGroup sync.WaitGroup
  137. var initErr error
  138. serverEntries := make(map[string]*dslSourcedServerEntry)
  139. for i := 1; i <= 128; i++ {
  140. initGroup.Add(1)
  141. go func(i int) (retErr error) {
  142. defer initGroup.Done()
  143. defer func() {
  144. if retErr != nil {
  145. initMutex.Lock()
  146. initErr = retErr
  147. initMutex.Unlock()
  148. }
  149. }()
  150. serverEntry := &protocol.ServerEntry{
  151. Tag: prng.Base64String(32),
  152. IpAddress: fmt.Sprintf("192.0.2.%d", i),
  153. SshUsername: prng.HexString(8),
  154. SshPassword: prng.HexString(32),
  155. SshHostKey: prng.Base64String(280),
  156. SshObfuscatedPort: prng.Range(1, 65535),
  157. SshObfuscatedKey: prng.HexString(32),
  158. Capabilities: []string{"OSSH"},
  159. Region: prng.HexString(1),
  160. ProviderID: strings.ToUpper(prng.HexString(8)),
  161. ConfigurationVersion: 0,
  162. Signature: prng.Base64String(80),
  163. }
  164. serverEntryFields, err := serverEntry.GetServerEntryFields()
  165. if err != nil {
  166. return errors.Trace(err)
  167. }
  168. packed, err := protocol.EncodePackedServerEntryFields(serverEntryFields)
  169. if err != nil {
  170. return errors.Trace(err)
  171. }
  172. source := fmt.Sprintf("DSL-compartment-%d", i)
  173. initMutex.Lock()
  174. if serverEntries[serverEntry.Tag] != nil {
  175. initMutex.Unlock()
  176. return errors.TraceNew("duplicate tag")
  177. }
  178. serverEntries[serverEntry.Tag] = &dslSourcedServerEntry{
  179. ServerEntryFields: packed,
  180. Source: source,
  181. }
  182. initMutex.Unlock()
  183. return nil
  184. }(i)
  185. }
  186. initGroup.Wait()
  187. if initErr != nil {
  188. return nil, errors.Trace(initErr)
  189. }
  190. b.untunneledServerEntries = serverEntries
  191. b.tunneledServerEntries = serverEntries
  192. return b, nil
  193. }
  194. func (b *TestDSLBackend) Start() error {
  195. logger := NewTestLoggerWithComponent("backend")
  196. listener, err := net.Listen("tcp", "127.0.0.1:0")
  197. if err != nil {
  198. return errors.Trace(err)
  199. }
  200. certificatePool := x509.NewCertPool()
  201. certificatePool.AddCert(b.tlsConfig.CACertificate)
  202. listener = tls.NewListener(
  203. listener,
  204. &tls.Config{
  205. Certificates: []tls.Certificate{*b.tlsConfig.BackendCertificate},
  206. ClientAuth: tls.RequireAndVerifyClientCert,
  207. ClientCAs: certificatePool,
  208. })
  209. mux := http.NewServeMux()
  210. handlerAdapter := func(
  211. w http.ResponseWriter,
  212. r *http.Request,
  213. handler func(bool, []byte) ([]byte, error)) (retErr error) {
  214. defer func() {
  215. if retErr != nil {
  216. logger.WithTrace().Warning(fmt.Sprintf("handler failed: %s\n", retErr))
  217. http.Error(w, retErr.Error(), http.StatusInternalServerError)
  218. }
  219. }()
  220. headerName := b.shim.ClientIPHeaderName()
  221. clientIPHeader, ok := r.Header[headerName]
  222. if !ok {
  223. return errors.Tracef("missing header: %s", headerName)
  224. }
  225. if len(clientIPHeader) != 1 ||
  226. (b.expectedClientIP != "" && clientIPHeader[0] != b.expectedClientIP) {
  227. return errors.Tracef("invalid header: %s", headerName)
  228. }
  229. headerName = b.shim.ClientGeoIPDataHeaderName()
  230. clientGeoIPDataHeader, ok := r.Header[headerName]
  231. if !ok {
  232. return errors.Tracef("missing header: %s", headerName)
  233. }
  234. var geoIPData common.GeoIPData
  235. if len(clientGeoIPDataHeader) != 1 ||
  236. json.Unmarshal([]byte(clientGeoIPDataHeader[0]), &geoIPData) != nil ||
  237. (b.expectedClientGeoIPData != nil && geoIPData != *b.expectedClientGeoIPData) {
  238. return errors.Tracef("invalid header: %s", headerName)
  239. }
  240. headerName = b.shim.ClientTunneledHeaderName()
  241. clientTunneledHeader, ok := r.Header[headerName]
  242. if !ok {
  243. return errors.Tracef("missing header: %s", headerName)
  244. }
  245. if len(clientTunneledHeader) != 1 ||
  246. !common.Contains([]string{"true", "false"}, clientTunneledHeader[0]) {
  247. return errors.Tracef("invalid header: %s", headerName)
  248. }
  249. tunneled := clientTunneledHeader[0] == "true"
  250. headerName = b.shim.HostIDHeaderName()
  251. hostIDHeader, ok := r.Header[headerName]
  252. if !ok {
  253. return errors.Tracef("missing header: %s", headerName)
  254. }
  255. if len(hostIDHeader) != 1 ||
  256. (b.expectedHostID != "" && hostIDHeader[0] != b.expectedHostID) {
  257. return errors.Tracef("invalid header: %s", headerName)
  258. }
  259. request, err := io.ReadAll(r.Body)
  260. if err != nil {
  261. return errors.Trace(err)
  262. }
  263. r.Body.Close()
  264. response, err := handler(tunneled, request)
  265. if err != nil {
  266. return errors.Trace(err)
  267. }
  268. _, err = w.Write(response)
  269. if err != nil {
  270. return errors.Trace(err)
  271. }
  272. return nil
  273. }
  274. mux.HandleFunc(b.shim.DiscoverServerEntriesRequestPath(),
  275. func(w http.ResponseWriter, r *http.Request) {
  276. _ = handlerAdapter(w, r, b.handleDiscoverServerEntries)
  277. })
  278. mux.HandleFunc(b.shim.GetServerEntriesRequestPath(),
  279. func(w http.ResponseWriter, r *http.Request) {
  280. _ = handlerAdapter(w, r, b.handleGetServerEntries)
  281. })
  282. mux.HandleFunc(b.shim.GetActiveOSLsRequestPath(),
  283. func(w http.ResponseWriter, r *http.Request) {
  284. _ = handlerAdapter(w, r, b.handleGetActiveOSLs)
  285. })
  286. mux.HandleFunc(b.shim.GetOSLFileSpecsRequestPath(),
  287. func(w http.ResponseWriter, r *http.Request) {
  288. _ = handlerAdapter(w, r, b.handleGetOSLFileSpecs)
  289. })
  290. server := &http.Server{
  291. Handler: mux,
  292. }
  293. go func() {
  294. _ = server.Serve(listener)
  295. }()
  296. b.listener = listener
  297. return nil
  298. }
  299. func (b *TestDSLBackend) Stop() {
  300. if b.listener == nil {
  301. return
  302. }
  303. _ = b.listener.Close()
  304. }
  305. func (b *TestDSLBackend) GetAddress() string {
  306. if b.listener == nil {
  307. return ""
  308. }
  309. return b.listener.Addr().String()
  310. }
  311. func (b *TestDSLBackend) GetServerEntryCount(isTunneled bool) int {
  312. if isTunneled {
  313. return len(b.tunneledServerEntries)
  314. }
  315. return len(b.untunneledServerEntries)
  316. }
  317. func (b *TestDSLBackend) SetServerEntries(
  318. isTunneled bool, encodedServerEntries []string) error {
  319. source := "DSL-untunneled"
  320. if isTunneled {
  321. source = "DSL-tunneled"
  322. }
  323. sourcedServerEntries := make(map[string]*dslSourcedServerEntry)
  324. for _, encodedServerEntry := range encodedServerEntries {
  325. serverEntryFields, err := protocol.DecodeServerEntryFields(
  326. encodedServerEntry, "", "")
  327. if err != nil {
  328. return errors.Trace(err)
  329. }
  330. packedServerEntryFields, err :=
  331. protocol.EncodePackedServerEntryFields(serverEntryFields)
  332. if err != nil {
  333. return errors.Trace(err)
  334. }
  335. sourcedServerEntries[serverEntryFields.GetTag()] = &dslSourcedServerEntry{
  336. ServerEntryFields: packedServerEntryFields,
  337. Source: source,
  338. }
  339. }
  340. if isTunneled {
  341. b.tunneledServerEntries = sourcedServerEntries
  342. } else {
  343. b.untunneledServerEntries = sourcedServerEntries
  344. }
  345. return nil
  346. }
  // SetOSLPaveData replaces the OSL pave data read by the request handlers.
  // While the slice is non-empty, discover requests must supply the FileKey
  // of every entry to receive any server entries. The get active OSLs and get
  // OSL file specs handlers answer from these entries. The value is stored
  // in an atomic.Value, so this may be called while the backend is serving.
  func (b *TestDSLBackend) SetOSLPaveData(oslPaveData []*osl.PaveData) {
  	b.oslPaveData.Store(oslPaveData)
  }
  350. func (b *TestDSLBackend) handleDiscoverServerEntries(
  351. tunneled bool,
  352. cborRequest []byte) ([]byte, error) {
  353. serverEntries := b.untunneledServerEntries
  354. if tunneled {
  355. serverEntries = b.tunneledServerEntries
  356. }
  357. _, oslKeys, discoverCount, err :=
  358. b.shim.UnmarshalDiscoverServerEntriesRequest(cborRequest)
  359. if err != nil {
  360. return nil, errors.Trace(err)
  361. }
  362. missingOSLs := false
  363. oslPaveDataValue := b.oslPaveData.Load()
  364. if oslPaveDataValue != nil {
  365. oslPaveData := oslPaveDataValue.([]*osl.PaveData)
  366. // When b.oslPaveData is set, the client must provide the expected OSL
  367. // keys in order to discover any server entries.
  368. for _, oslPaveData := range oslPaveData {
  369. found := false
  370. for _, key := range oslKeys {
  371. if bytes.Equal(key, oslPaveData.FileKey) {
  372. found = true
  373. break
  374. }
  375. }
  376. if !found {
  377. missingOSLs = true
  378. break
  379. }
  380. }
  381. }
  382. var versionedServerEntryTags []*struct {
  383. Tag []byte
  384. Version int32
  385. }
  386. if !missingOSLs {
  387. count := 0
  388. for tag := range serverEntries {
  389. if count >= int(discoverCount) {
  390. break
  391. }
  392. count += 1
  393. // Test server entry tags are base64-encoded random byte strings.
  394. serverEntryTag, err := base64.StdEncoding.DecodeString(tag)
  395. if err != nil {
  396. return nil, errors.Trace(err)
  397. }
  398. versionedServerEntryTags = append(
  399. versionedServerEntryTags,
  400. &struct {
  401. Tag []byte
  402. Version int32
  403. }{serverEntryTag, 0})
  404. }
  405. }
  406. cborResponse, err := b.shim.MarshalDiscoverServerEntriesResponse(
  407. versionedServerEntryTags)
  408. if err != nil {
  409. return nil, errors.Trace(err)
  410. }
  411. return cborResponse, nil
  412. }
  413. func (b *TestDSLBackend) handleGetServerEntries(
  414. tunneled bool,
  415. cborRequest []byte) ([]byte, error) {
  416. serverEntries := b.untunneledServerEntries
  417. if tunneled {
  418. serverEntries = b.tunneledServerEntries
  419. }
  420. _, serverEntryTags, err :=
  421. b.shim.UnmarshalGetServerEntriesRequest(cborRequest)
  422. if err != nil {
  423. return nil, errors.Trace(err)
  424. }
  425. var sourcedServerEntryTags []*struct {
  426. ServerEntryFields protocol.PackedServerEntryFields
  427. Source string
  428. }
  429. for _, serverEntryTag := range serverEntryTags {
  430. tag := base64.StdEncoding.EncodeToString(serverEntryTag)
  431. sourcedServerEntry, ok := serverEntries[tag]
  432. if !ok {
  433. // An actual DSL backend must return empty slot in this case, as
  434. // the requested server entry could be pruned or unavailable. For
  435. // this test, this case is unexpected.
  436. return nil, errors.TraceNew("unknown server entry tag")
  437. }
  438. sourcedServerEntryTags = append(
  439. sourcedServerEntryTags, &struct {
  440. ServerEntryFields protocol.PackedServerEntryFields
  441. Source string
  442. }{sourcedServerEntry.ServerEntryFields, sourcedServerEntry.Source})
  443. }
  444. cborResponse, err := b.shim.MarshalGetServerEntriesResponse(
  445. sourcedServerEntryTags)
  446. if err != nil {
  447. return nil, errors.Trace(err)
  448. }
  449. return cborResponse, nil
  450. }
  451. func (b *TestDSLBackend) handleGetActiveOSLs(
  452. _ bool,
  453. cborRequest []byte) ([]byte, error) {
  454. _, err := b.shim.UnmarshalGetActiveOSLsRequest(cborRequest)
  455. if err != nil {
  456. return nil, errors.Trace(err)
  457. }
  458. var activeOSLIDs [][]byte
  459. oslPaveData := b.oslPaveData.Load().([]*osl.PaveData)
  460. for _, oslPaveData := range oslPaveData {
  461. activeOSLIDs = append(activeOSLIDs, oslPaveData.FileSpec.ID)
  462. }
  463. cborResponse, err := b.shim.MarshalGetActiveOSLsResponse(activeOSLIDs)
  464. if err != nil {
  465. return nil, errors.Trace(err)
  466. }
  467. return cborResponse, nil
  468. }
  469. func (b *TestDSLBackend) handleGetOSLFileSpecs(
  470. _ bool,
  471. cborRequest []byte) ([]byte, error) {
  472. _, oslIDs, err := b.shim.UnmarshalGetOSLFileSpecsRequest(cborRequest)
  473. if err != nil {
  474. return nil, errors.Trace(err)
  475. }
  476. var oslFileSpecs [][]byte
  477. oslPaveData := b.oslPaveData.Load().([]*osl.PaveData)
  478. for _, oslID := range oslIDs {
  479. var matchingPaveData *osl.PaveData
  480. for _, oslPaveData := range oslPaveData {
  481. if bytes.Equal(oslID, oslPaveData.FileSpec.ID) {
  482. matchingPaveData = oslPaveData
  483. break
  484. }
  485. }
  486. if matchingPaveData == nil {
  487. // An actual DSL backend must return empty slot in this case, as
  488. // the requested OSL may no longer be active. For this test, this
  489. // case is unexpected.
  490. return nil, errors.TraceNew("unknown OSL ID")
  491. }
  492. cborOSLFileSpec, err := protocol.CBOREncoding.Marshal(matchingPaveData.FileSpec)
  493. if err != nil {
  494. return nil, errors.Trace(err)
  495. }
  496. oslFileSpecs = append(oslFileSpecs, cborOSLFileSpec)
  497. }
  498. cborResponse, err := b.shim.MarshalGetOSLFileSpecsResponse(oslFileSpecs)
  499. if err != nil {
  500. return nil, errors.Trace(err)
  501. }
  502. return cborResponse, nil
  503. }
  504. func InitializeTestOSLPaveData() ([]*osl.PaveData, []*osl.PaveData, []*osl.SLOK, error) {
  505. // Adapted from testObfuscatedRemoteServerLists in psiphon/remoteServerList_test.go
  506. oslConfigJSONTemplate := `
  507. {
  508. "Schemes" : [
  509. {
  510. "Epoch" : "%s",
  511. "PaveDataOSLCount" : 2,
  512. "Regions" : [],
  513. "PropagationChannelIDs" : ["%s"],
  514. "MasterKey" : "vwab2WY3eNyMBpyFVPtsivMxF4MOpNHM/T7rHJIXctg=",
  515. "SeedSpecs" : [
  516. {
  517. "ID" : "KuP2V6gLcROIFzb/27fUVu4SxtEfm2omUoISlrWv1mA=",
  518. "UpstreamSubnets" : ["0.0.0.0/0"],
  519. "Targets" :
  520. {
  521. "BytesRead" : 1,
  522. "BytesWritten" : 1,
  523. "PortForwardDurationNanoseconds" : 1
  524. }
  525. }
  526. ],
  527. "SeedSpecThreshold" : 1,
  528. "SeedPeriodNanoseconds" : %d,
  529. "SeedPeriodKeySplits": [
  530. {
  531. "Total": 1,
  532. "Threshold": 1
  533. }
  534. ]
  535. }
  536. ]
  537. }`
  538. now := time.Now().UTC()
  539. seedPeriod := 1 * time.Second
  540. epoch := now.Truncate(seedPeriod)
  541. epochStr := epoch.Format(time.RFC3339Nano)
  542. propagationChannelID := prng.HexString(8)
  543. oslConfigJSON := fmt.Sprintf(
  544. oslConfigJSONTemplate,
  545. epochStr,
  546. propagationChannelID,
  547. seedPeriod)
  548. oslConfig, err := osl.LoadConfig([]byte(oslConfigJSON))
  549. if err != nil {
  550. return nil, nil, nil, errors.Trace(err)
  551. }
  552. oslPaveData, err := oslConfig.GetPaveData(0)
  553. if err != nil {
  554. return nil, nil, nil, errors.Trace(err)
  555. }
  556. backendPaveData1, ok := oslPaveData[propagationChannelID]
  557. if !ok {
  558. return nil, nil, nil, errors.TraceNew("unexpected missing OSL file data")
  559. }
  560. // Mock seeding SLOKs
  561. //
  562. // Normally, clients supplying the specified propagation channel ID would
  563. // receive SLOKs via the psiphond tunnel connection
  564. seedState := oslConfig.NewClientSeedState("", propagationChannelID, nil)
  565. seedPortForward := seedState.NewClientSeedPortForward(net.ParseIP("0.0.0.0"), nil)
  566. seedPortForward.UpdateProgress(1, 1, 1)
  567. payload := seedState.GetSeedPayload()
  568. if len(payload.SLOKs) != 1 {
  569. return nil, nil, nil, errors.Tracef("unexpected SLOK count %d", len(payload.SLOKs))
  570. }
  571. clientSLOKs := payload.SLOKs
  572. // Rollover to the next OSL time period and generate a new set of active
  573. // OSLs and SLOKs.
  574. time.Sleep(2 * seedPeriod)
  575. oslPaveData, err = oslConfig.GetPaveData(0)
  576. if err != nil {
  577. return nil, nil, nil, errors.Trace(err)
  578. }
  579. backendPaveData2, ok := oslPaveData[propagationChannelID]
  580. if !ok {
  581. return nil, nil, nil, errors.TraceNew("unexpected missing OSL file data")
  582. }
  583. seedState = oslConfig.NewClientSeedState("", propagationChannelID, nil)
  584. seedPortForward = seedState.NewClientSeedPortForward(net.ParseIP("0.0.0.0"), nil)
  585. seedPortForward.UpdateProgress(1, 1, 1)
  586. payload = seedState.GetSeedPayload()
  587. if len(payload.SLOKs) != 1 {
  588. return nil, nil, nil, errors.Tracef("unexpected SLOK count %d", len(payload.SLOKs))
  589. }
  590. clientSLOKs = append(clientSLOKs, payload.SLOKs...)
  591. // Double check that PaveData periods don't overlap.
  592. for _, paveData1 := range backendPaveData1 {
  593. for _, paveData2 := range backendPaveData2 {
  594. if bytes.Equal(paveData1.FileSpec.ID, paveData2.FileSpec.ID) {
  595. return nil, nil, nil, errors.TraceNew("unexpected pave data overlap")
  596. }
  597. }
  598. }
  599. return backendPaveData1, backendPaveData2, clientSLOKs, nil
  600. }
  // TestDSLTLSConfig holds a test PKI, as generated by NewTestDSLTLSConfig:
  // a root CA, a backend server certificate, and a relay client certificate.
  // Each is available both parsed and PEM-encoded.
  type TestDSLTLSConfig struct {
  	// CACertificate is the self-signed root that issued the backend and
  	// relay certificates. TestDSLBackend.Start uses it to verify client
  	// certificates.
  	CACertificate    *x509.Certificate
  	CACertificatePEM []byte
  	// BackendCertificate is presented by TestDSLBackend.Start.
  	BackendCertificate    *tls.Certificate
  	BackendCertificatePEM []byte
  	BackendKeyPEM         []byte
  	// The relay certificate and key; WriteRelayFiles writes their PEM forms
  	// (and the CA certificate PEM) to disk.
  	RelayCertificate    *tls.Certificate
  	RelayCertificatePEM []byte
  	RelayKeyPEM         []byte
  }
  611. func NewTestDSLTLSConfig() (*TestDSLTLSConfig, error) {
  612. CAPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  613. if err != nil {
  614. return nil, errors.Trace(err)
  615. }
  616. now := time.Now()
  617. template := &x509.Certificate{
  618. SerialNumber: big.NewInt(1),
  619. Subject: pkix.Name{
  620. Organization: []string{"test root CA"},
  621. },
  622. NotBefore: now,
  623. NotAfter: now.AddDate(0, 0, 1),
  624. IsCA: true,
  625. BasicConstraintsValid: true,
  626. KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
  627. }
  628. CACertificateDER, err := x509.CreateCertificate(
  629. rand.Reader, template, template, &CAPrivateKey.PublicKey, CAPrivateKey)
  630. if err != nil {
  631. return nil, errors.Trace(err)
  632. }
  633. CACertificatePEM := pem.EncodeToMemory(
  634. &pem.Block{Type: "CERTIFICATE", Bytes: CACertificateDER})
  635. if err != nil {
  636. return nil, errors.Trace(err)
  637. }
  638. CACertificate, err := x509.ParseCertificate(CACertificateDER)
  639. if err != nil {
  640. return nil, errors.Trace(err)
  641. }
  642. issueCertificate := func(
  643. name string, isServer bool) (
  644. *tls.Certificate, []byte, []byte, error) {
  645. privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  646. if err != nil {
  647. return nil, nil, nil, errors.Trace(err)
  648. }
  649. now := time.Now()
  650. template := &x509.Certificate{
  651. SerialNumber: big.NewInt(time.Now().UnixNano()),
  652. Subject: pkix.Name{
  653. CommonName: name,
  654. },
  655. NotBefore: now,
  656. NotAfter: now.AddDate(0, 0, 1),
  657. KeyUsage: x509.KeyUsageDigitalSignature,
  658. }
  659. if isServer {
  660. template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}
  661. template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
  662. } else {
  663. template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
  664. }
  665. certificateDER, err := x509.CreateCertificate(
  666. rand.Reader, template, CACertificate, &privateKey.PublicKey, CAPrivateKey)
  667. if err != nil {
  668. return nil, nil, nil, errors.Trace(err)
  669. }
  670. certPEM := pem.EncodeToMemory(
  671. &pem.Block{Type: "CERTIFICATE", Bytes: certificateDER})
  672. keyPEM := pem.EncodeToMemory(
  673. &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
  674. tlsCertificate, err := tls.X509KeyPair(certPEM, keyPEM)
  675. if err != nil {
  676. return nil, nil, nil, errors.Trace(err)
  677. }
  678. return &tlsCertificate, certPEM, keyPEM, nil
  679. }
  680. backendCertificate, backendCertificatePEM, backendKeyPEM, err :=
  681. issueCertificate("backend", true)
  682. if err != nil {
  683. return nil, errors.Trace(err)
  684. }
  685. relayCertificate, relayCertificatePEM, relayKeyPEM, err :=
  686. issueCertificate("relay", false)
  687. if err != nil {
  688. return nil, errors.Trace(err)
  689. }
  690. return &TestDSLTLSConfig{
  691. CACertificate: CACertificate,
  692. CACertificatePEM: CACertificatePEM,
  693. BackendCertificate: backendCertificate,
  694. BackendCertificatePEM: backendCertificatePEM,
  695. BackendKeyPEM: backendKeyPEM,
  696. RelayCertificate: relayCertificate,
  697. RelayCertificatePEM: relayCertificatePEM,
  698. RelayKeyPEM: relayKeyPEM,
  699. }, nil
  700. }
  701. func (config *TestDSLTLSConfig) WriteRelayFiles(dirName string) (
  702. string, string, string, error) {
  703. caCertificatesFilename := filepath.Join(
  704. dirName, "dslRelayCACert.pem")
  705. err := os.WriteFile(
  706. caCertificatesFilename,
  707. config.CACertificatePEM,
  708. 0644)
  709. if err != nil {
  710. return "", "", "", errors.Trace(err)
  711. }
  712. hostCertificateFilename := filepath.Join(
  713. dirName, "dslRelayHostCert.pem")
  714. err = os.WriteFile(
  715. hostCertificateFilename,
  716. config.RelayCertificatePEM,
  717. 0644)
  718. if err != nil {
  719. return "", "", "", errors.Trace(err)
  720. }
  721. hostKeyFilename := filepath.Join(
  722. dirName, "dslRelayHostKey.pem")
  723. err = os.WriteFile(
  724. hostKeyFilename,
  725. config.RelayKeyPEM,
  726. 0644)
  727. if err != nil {
  728. return "", "", "", errors.Trace(err)
  729. }
  730. return caCertificatesFilename,
  731. hostCertificateFilename,
  732. hostKeyFilename,
  733. nil
  734. }