dsl.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903
  1. /*
  2. * Copyright (c) 2025, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package testutils
  20. import (
  21. "bytes"
  22. "crypto/rand"
  23. "crypto/rsa"
  24. "crypto/tls"
  25. "crypto/x509"
  26. "crypto/x509/pkix"
  27. "encoding/base64"
  28. "encoding/json"
  29. "encoding/pem"
  30. "fmt"
  31. "io"
  32. "math/big"
  33. "net"
  34. "net/http"
  35. "os"
  36. "path/filepath"
  37. "strings"
  38. "sync"
  39. "sync/atomic"
  40. "time"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
  44. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  45. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  46. )
  47. type DSLBackendTestShim interface {
  48. ClientIPHeaderName() string
  49. ClientGeoIPDataHeaderName() string
  50. ClientTunneledHeaderName() string
  51. HostIDHeaderName() string
  52. DiscoverServerEntriesRequestPath() string
  53. GetServerEntriesRequestPath() string
  54. GetActiveOSLsRequestPath() string
  55. GetOSLFileSpecsRequestPath() string
  56. UnmarshalDiscoverServerEntriesRequest(
  57. cborRequest []byte) (
  58. apiParams protocol.PackedAPIParameters,
  59. oslKeys [][]byte,
  60. discoverCount int32,
  61. retErr error)
  62. MarshalDiscoverServerEntriesResponse(
  63. versionedServerEntryTags []*struct {
  64. Tag []byte
  65. Version int32
  66. PrioritizeDial bool
  67. }) (
  68. cborResponse []byte,
  69. retErr error)
  70. UnmarshalGetServerEntriesRequest(
  71. cborRequest []byte) (
  72. apiParams protocol.PackedAPIParameters,
  73. serverEntryTags [][]byte,
  74. retErr error)
  75. MarshalGetServerEntriesResponse(
  76. sourcedServerEntries []*struct {
  77. ServerEntryFields protocol.PackedServerEntryFields
  78. Source string
  79. }) (
  80. cborResponse []byte,
  81. retErr error)
  82. UnmarshalGetActiveOSLsRequest(
  83. cborRequest []byte) (
  84. apiParams protocol.PackedAPIParameters,
  85. retErr error)
  86. MarshalGetActiveOSLsResponse(
  87. activeOSLIDs [][]byte) (
  88. cborResponse []byte,
  89. retErr error)
  90. UnmarshalGetOSLFileSpecsRequest(
  91. cborRequest []byte) (
  92. apiParams protocol.PackedAPIParameters,
  93. oslIDs [][]byte,
  94. retErr error)
  95. MarshalGetOSLFileSpecsResponse(
  96. oslFileSpecs [][]byte) (
  97. cborResponse []byte,
  98. retErr error)
  99. }
  100. // TestDSLBackend is a mock DSL backend intended only for testing.
  101. type TestDSLBackend struct {
  102. shim DSLBackendTestShim
  103. tlsConfig *TestDSLTLSConfig
  104. expectedClientIP string
  105. expectedClientGeoIPData *common.GeoIPData
  106. expectedHostID string
  107. oslPaveData atomic.Value
  108. untunneledServerEntries map[string]*dslSourcedServerEntry
  109. tunneledServerEntries map[string]*dslSourcedServerEntry
  110. listener net.Listener
  111. }
  112. type dslSourcedServerEntry struct {
  113. ServerEntryFields protocol.PackedServerEntryFields
  114. Source string
  115. PrioritizeDial bool
  116. }
  117. func NewTestDSLBackend(
  118. shim DSLBackendTestShim,
  119. tlsConfig *TestDSLTLSConfig,
  120. expectedClientIP string,
  121. expectedClientGeoIPData *common.GeoIPData,
  122. expectedHostID string,
  123. oslPaveData []*osl.PaveData) (*TestDSLBackend, error) {
  124. b := &TestDSLBackend{
  125. shim: shim,
  126. tlsConfig: tlsConfig,
  127. expectedClientIP: expectedClientIP,
  128. expectedClientGeoIPData: expectedClientGeoIPData,
  129. expectedHostID: expectedHostID,
  130. }
  131. b.oslPaveData.Store(oslPaveData)
  132. // Generate mock server entries.
  133. // Run GenerateConfig concurrently to try to take advantage of multiple
  134. // CPU cores.
  135. //
  136. // Update: no longer using server.GenerateConfig due to import cycle.
  137. var initMutex sync.Mutex
  138. var initGroup sync.WaitGroup
  139. var initErr error
  140. serverEntries := make(map[string]*dslSourcedServerEntry)
  141. for i := 1; i <= 128; i++ {
  142. initGroup.Add(1)
  143. go func(i int) (retErr error) {
  144. defer initGroup.Done()
  145. defer func() {
  146. if retErr != nil {
  147. initMutex.Lock()
  148. initErr = retErr
  149. initMutex.Unlock()
  150. }
  151. }()
  152. serverEntry := &protocol.ServerEntry{
  153. Tag: prng.Base64String(32),
  154. IpAddress: fmt.Sprintf("192.0.2.%d", i),
  155. SshUsername: prng.HexString(8),
  156. SshPassword: prng.HexString(32),
  157. SshHostKey: prng.Base64String(280),
  158. SshObfuscatedPort: prng.Range(1, 65535),
  159. SshObfuscatedKey: prng.HexString(32),
  160. Capabilities: []string{"OSSH"},
  161. Region: prng.HexString(1),
  162. ProviderID: strings.ToUpper(prng.HexString(8)),
  163. ConfigurationVersion: 0,
  164. Signature: prng.Base64String(80),
  165. }
  166. serverEntryFields, err := serverEntry.GetServerEntryFields()
  167. if err != nil {
  168. return errors.Trace(err)
  169. }
  170. packed, err := protocol.EncodePackedServerEntryFields(serverEntryFields)
  171. if err != nil {
  172. return errors.Trace(err)
  173. }
  174. source := fmt.Sprintf("DSL-compartment-%d", i)
  175. initMutex.Lock()
  176. if serverEntries[serverEntry.Tag] != nil {
  177. initMutex.Unlock()
  178. return errors.TraceNew("duplicate tag")
  179. }
  180. serverEntries[serverEntry.Tag] = &dslSourcedServerEntry{
  181. ServerEntryFields: packed,
  182. Source: source,
  183. PrioritizeDial: prng.FlipCoin(),
  184. }
  185. initMutex.Unlock()
  186. return nil
  187. }(i)
  188. }
  189. initGroup.Wait()
  190. if initErr != nil {
  191. return nil, errors.Trace(initErr)
  192. }
  193. b.untunneledServerEntries = serverEntries
  194. b.tunneledServerEntries = serverEntries
  195. return b, nil
  196. }
  197. func (b *TestDSLBackend) Start() error {
  198. logger := NewTestLoggerWithComponent("backend")
  199. listener, err := net.Listen("tcp", "127.0.0.1:0")
  200. if err != nil {
  201. return errors.Trace(err)
  202. }
  203. certificatePool := x509.NewCertPool()
  204. certificatePool.AddCert(b.tlsConfig.CACertificate)
  205. listener = tls.NewListener(
  206. listener,
  207. &tls.Config{
  208. Certificates: []tls.Certificate{*b.tlsConfig.BackendCertificate},
  209. ClientAuth: tls.RequireAndVerifyClientCert,
  210. ClientCAs: certificatePool,
  211. })
  212. mux := http.NewServeMux()
  213. handlerAdapter := func(
  214. w http.ResponseWriter,
  215. r *http.Request,
  216. handler func(bool, []byte) ([]byte, error)) (retErr error) {
  217. defer func() {
  218. if retErr != nil {
  219. logger.WithTrace().Warning(fmt.Sprintf("handler failed: %s\n", retErr))
  220. http.Error(w, retErr.Error(), http.StatusInternalServerError)
  221. }
  222. }()
  223. headerName := b.shim.ClientIPHeaderName()
  224. clientIPHeader, ok := r.Header[headerName]
  225. if !ok {
  226. return errors.Tracef("missing header: %s", headerName)
  227. }
  228. if len(clientIPHeader) != 1 ||
  229. (b.expectedClientIP != "" && clientIPHeader[0] != b.expectedClientIP) {
  230. return errors.Tracef("invalid header: %s", headerName)
  231. }
  232. headerName = b.shim.ClientGeoIPDataHeaderName()
  233. clientGeoIPDataHeader, ok := r.Header[headerName]
  234. if !ok {
  235. return errors.Tracef("missing header: %s", headerName)
  236. }
  237. var geoIPData common.GeoIPData
  238. if len(clientGeoIPDataHeader) != 1 ||
  239. json.Unmarshal([]byte(clientGeoIPDataHeader[0]), &geoIPData) != nil ||
  240. (b.expectedClientGeoIPData != nil && geoIPData != *b.expectedClientGeoIPData) {
  241. return errors.Tracef("invalid header: %s", headerName)
  242. }
  243. headerName = b.shim.ClientTunneledHeaderName()
  244. clientTunneledHeader, ok := r.Header[headerName]
  245. if !ok {
  246. return errors.Tracef("missing header: %s", headerName)
  247. }
  248. if len(clientTunneledHeader) != 1 ||
  249. !common.Contains([]string{"true", "false"}, clientTunneledHeader[0]) {
  250. return errors.Tracef("invalid header: %s", headerName)
  251. }
  252. tunneled := clientTunneledHeader[0] == "true"
  253. headerName = b.shim.HostIDHeaderName()
  254. hostIDHeader, ok := r.Header[headerName]
  255. if !ok {
  256. return errors.Tracef("missing header: %s", headerName)
  257. }
  258. if len(hostIDHeader) != 1 ||
  259. (b.expectedHostID != "" && hostIDHeader[0] != b.expectedHostID) {
  260. return errors.Tracef("invalid header: %s", headerName)
  261. }
  262. request, err := io.ReadAll(r.Body)
  263. if err != nil {
  264. return errors.Trace(err)
  265. }
  266. r.Body.Close()
  267. response, err := handler(tunneled, request)
  268. if err != nil {
  269. return errors.Trace(err)
  270. }
  271. _, err = w.Write(response)
  272. if err != nil {
  273. return errors.Trace(err)
  274. }
  275. return nil
  276. }
  277. mux.HandleFunc(b.shim.DiscoverServerEntriesRequestPath(),
  278. func(w http.ResponseWriter, r *http.Request) {
  279. _ = handlerAdapter(w, r, b.handleDiscoverServerEntries)
  280. })
  281. mux.HandleFunc(b.shim.GetServerEntriesRequestPath(),
  282. func(w http.ResponseWriter, r *http.Request) {
  283. _ = handlerAdapter(w, r, b.handleGetServerEntries)
  284. })
  285. mux.HandleFunc(b.shim.GetActiveOSLsRequestPath(),
  286. func(w http.ResponseWriter, r *http.Request) {
  287. _ = handlerAdapter(w, r, b.handleGetActiveOSLs)
  288. })
  289. mux.HandleFunc(b.shim.GetOSLFileSpecsRequestPath(),
  290. func(w http.ResponseWriter, r *http.Request) {
  291. _ = handlerAdapter(w, r, b.handleGetOSLFileSpecs)
  292. })
  293. server := &http.Server{
  294. Handler: mux,
  295. }
  296. go func() {
  297. _ = server.Serve(listener)
  298. }()
  299. b.listener = listener
  300. return nil
  301. }
  302. func (b *TestDSLBackend) Stop() {
  303. if b.listener == nil {
  304. return
  305. }
  306. _ = b.listener.Close()
  307. }
  308. func (b *TestDSLBackend) GetAddress() string {
  309. if b.listener == nil {
  310. return ""
  311. }
  312. return b.listener.Addr().String()
  313. }
  314. func (b *TestDSLBackend) GetServerEntryCount(isTunneled bool) int {
  315. if isTunneled {
  316. return len(b.tunneledServerEntries)
  317. }
  318. return len(b.untunneledServerEntries)
  319. }
  320. func (b *TestDSLBackend) GetServerEntryProperties(
  321. serverEntryTag string) (string, bool, error) {
  322. entry, ok := b.untunneledServerEntries[serverEntryTag]
  323. if !ok {
  324. entry, ok = b.tunneledServerEntries[serverEntryTag]
  325. if !ok {
  326. return "", false, errors.TraceNew("unknown server entry tag")
  327. }
  328. }
  329. return entry.Source, entry.PrioritizeDial, nil
  330. }
  331. func (b *TestDSLBackend) SetServerEntries(
  332. isTunneled bool,
  333. prioritizeDial bool,
  334. encodedServerEntries []string) error {
  335. source := "DSL-untunneled"
  336. if isTunneled {
  337. source = "DSL-tunneled"
  338. }
  339. sourcedServerEntries := make(map[string]*dslSourcedServerEntry)
  340. for _, encodedServerEntry := range encodedServerEntries {
  341. serverEntryFields, err := protocol.DecodeServerEntryFields(
  342. encodedServerEntry, "", "")
  343. if err != nil {
  344. return errors.Trace(err)
  345. }
  346. packedServerEntryFields, err :=
  347. protocol.EncodePackedServerEntryFields(serverEntryFields)
  348. if err != nil {
  349. return errors.Trace(err)
  350. }
  351. sourcedServerEntries[serverEntryFields.GetTag()] = &dslSourcedServerEntry{
  352. ServerEntryFields: packedServerEntryFields,
  353. Source: source,
  354. PrioritizeDial: prioritizeDial,
  355. }
  356. }
  357. if isTunneled {
  358. b.tunneledServerEntries = sourcedServerEntries
  359. } else {
  360. b.untunneledServerEntries = sourcedServerEntries
  361. }
  362. return nil
  363. }
  364. func (b *TestDSLBackend) SetOSLPaveData(oslPaveData []*osl.PaveData) {
  365. b.oslPaveData.Store(oslPaveData)
  366. }
  367. func (b *TestDSLBackend) handleDiscoverServerEntries(
  368. tunneled bool,
  369. cborRequest []byte) ([]byte, error) {
  370. serverEntries := b.untunneledServerEntries
  371. if tunneled {
  372. serverEntries = b.tunneledServerEntries
  373. }
  374. _, oslKeys, discoverCount, err :=
  375. b.shim.UnmarshalDiscoverServerEntriesRequest(cborRequest)
  376. if err != nil {
  377. return nil, errors.Trace(err)
  378. }
  379. missingOSLs := false
  380. oslPaveDataValue := b.oslPaveData.Load()
  381. if oslPaveDataValue != nil {
  382. oslPaveData := oslPaveDataValue.([]*osl.PaveData)
  383. // When b.oslPaveData is set, the client must provide the expected OSL
  384. // keys in order to discover any server entries.
  385. for _, oslPaveData := range oslPaveData {
  386. found := false
  387. for _, key := range oslKeys {
  388. if bytes.Equal(key, oslPaveData.FileKey) {
  389. found = true
  390. break
  391. }
  392. }
  393. if !found {
  394. missingOSLs = true
  395. break
  396. }
  397. }
  398. }
  399. var versionedServerEntryTags []*struct {
  400. Tag []byte
  401. Version int32
  402. PrioritizeDial bool
  403. }
  404. if !missingOSLs {
  405. count := 0
  406. for tag, sourcedServerEntry := range serverEntries {
  407. if count >= int(discoverCount) {
  408. break
  409. }
  410. count += 1
  411. // Test server entry tags are base64-encoded random byte strings.
  412. serverEntryTag, err := base64.StdEncoding.DecodeString(tag)
  413. if err != nil {
  414. return nil, errors.Trace(err)
  415. }
  416. versionedServerEntryTags = append(
  417. versionedServerEntryTags,
  418. &struct {
  419. Tag []byte
  420. Version int32
  421. PrioritizeDial bool
  422. }{serverEntryTag, 0, sourcedServerEntry.PrioritizeDial})
  423. }
  424. }
  425. cborResponse, err := b.shim.MarshalDiscoverServerEntriesResponse(
  426. versionedServerEntryTags)
  427. if err != nil {
  428. return nil, errors.Trace(err)
  429. }
  430. return cborResponse, nil
  431. }
  432. func (b *TestDSLBackend) handleGetServerEntries(
  433. tunneled bool,
  434. cborRequest []byte) ([]byte, error) {
  435. serverEntries := b.untunneledServerEntries
  436. if tunneled {
  437. serverEntries = b.tunneledServerEntries
  438. }
  439. _, serverEntryTags, err :=
  440. b.shim.UnmarshalGetServerEntriesRequest(cborRequest)
  441. if err != nil {
  442. return nil, errors.Trace(err)
  443. }
  444. var sourcedServerEntryTags []*struct {
  445. ServerEntryFields protocol.PackedServerEntryFields
  446. Source string
  447. }
  448. for _, serverEntryTag := range serverEntryTags {
  449. tag := base64.StdEncoding.EncodeToString(serverEntryTag)
  450. sourcedServerEntry, ok := serverEntries[tag]
  451. if !ok {
  452. // An actual DSL backend must return empty slot in this case, as
  453. // the requested server entry could be pruned or unavailable. For
  454. // this test, this case is unexpected.
  455. return nil, errors.TraceNew("unknown server entry tag")
  456. }
  457. sourcedServerEntryTags = append(
  458. sourcedServerEntryTags, &struct {
  459. ServerEntryFields protocol.PackedServerEntryFields
  460. Source string
  461. }{sourcedServerEntry.ServerEntryFields, sourcedServerEntry.Source})
  462. }
  463. cborResponse, err := b.shim.MarshalGetServerEntriesResponse(
  464. sourcedServerEntryTags)
  465. if err != nil {
  466. return nil, errors.Trace(err)
  467. }
  468. return cborResponse, nil
  469. }
  470. func (b *TestDSLBackend) handleGetActiveOSLs(
  471. _ bool,
  472. cborRequest []byte) ([]byte, error) {
  473. _, err := b.shim.UnmarshalGetActiveOSLsRequest(cborRequest)
  474. if err != nil {
  475. return nil, errors.Trace(err)
  476. }
  477. var activeOSLIDs [][]byte
  478. oslPaveData := b.oslPaveData.Load().([]*osl.PaveData)
  479. for _, oslPaveData := range oslPaveData {
  480. activeOSLIDs = append(activeOSLIDs, oslPaveData.FileSpec.ID)
  481. }
  482. cborResponse, err := b.shim.MarshalGetActiveOSLsResponse(activeOSLIDs)
  483. if err != nil {
  484. return nil, errors.Trace(err)
  485. }
  486. return cborResponse, nil
  487. }
  488. func (b *TestDSLBackend) handleGetOSLFileSpecs(
  489. _ bool,
  490. cborRequest []byte) ([]byte, error) {
  491. _, oslIDs, err := b.shim.UnmarshalGetOSLFileSpecsRequest(cborRequest)
  492. if err != nil {
  493. return nil, errors.Trace(err)
  494. }
  495. var oslFileSpecs [][]byte
  496. oslPaveData := b.oslPaveData.Load().([]*osl.PaveData)
  497. for _, oslID := range oslIDs {
  498. var matchingPaveData *osl.PaveData
  499. for _, oslPaveData := range oslPaveData {
  500. if bytes.Equal(oslID, oslPaveData.FileSpec.ID) {
  501. matchingPaveData = oslPaveData
  502. break
  503. }
  504. }
  505. if matchingPaveData == nil {
  506. // An actual DSL backend must return empty slot in this case, as
  507. // the requested OSL may no longer be active. For this test, this
  508. // case is unexpected.
  509. return nil, errors.TraceNew("unknown OSL ID")
  510. }
  511. cborOSLFileSpec, err := protocol.CBOREncoding.Marshal(matchingPaveData.FileSpec)
  512. if err != nil {
  513. return nil, errors.Trace(err)
  514. }
  515. oslFileSpecs = append(oslFileSpecs, cborOSLFileSpec)
  516. }
  517. cborResponse, err := b.shim.MarshalGetOSLFileSpecsResponse(oslFileSpecs)
  518. if err != nil {
  519. return nil, errors.Trace(err)
  520. }
  521. return cborResponse, nil
  522. }
  523. func InitializeTestOSLPaveData() ([]*osl.PaveData, []*osl.PaveData, []*osl.SLOK, error) {
  524. // Adapted from testObfuscatedRemoteServerLists in psiphon/remoteServerList_test.go
  525. oslConfigJSONTemplate := `
  526. {
  527. "Schemes" : [
  528. {
  529. "Epoch" : "%s",
  530. "PaveDataOSLCount" : 2,
  531. "Regions" : [],
  532. "PropagationChannelIDs" : ["%s"],
  533. "MasterKey" : "vwab2WY3eNyMBpyFVPtsivMxF4MOpNHM/T7rHJIXctg=",
  534. "SeedSpecs" : [
  535. {
  536. "ID" : "KuP2V6gLcROIFzb/27fUVu4SxtEfm2omUoISlrWv1mA=",
  537. "UpstreamSubnets" : ["0.0.0.0/0"],
  538. "Targets" :
  539. {
  540. "BytesRead" : 1,
  541. "BytesWritten" : 1,
  542. "PortForwardDurationNanoseconds" : 1
  543. }
  544. }
  545. ],
  546. "SeedSpecThreshold" : 1,
  547. "SeedPeriodNanoseconds" : %d,
  548. "SeedPeriodKeySplits": [
  549. {
  550. "Total": 1,
  551. "Threshold": 1
  552. }
  553. ]
  554. }
  555. ]
  556. }`
  557. now := time.Now().UTC()
  558. seedPeriod := 1 * time.Second
  559. epoch := now.Truncate(seedPeriod)
  560. epochStr := epoch.Format(time.RFC3339Nano)
  561. propagationChannelID := prng.HexString(8)
  562. oslConfigJSON := fmt.Sprintf(
  563. oslConfigJSONTemplate,
  564. epochStr,
  565. propagationChannelID,
  566. seedPeriod)
  567. oslConfig, err := osl.LoadConfig([]byte(oslConfigJSON))
  568. if err != nil {
  569. return nil, nil, nil, errors.Trace(err)
  570. }
  571. oslPaveData, err := oslConfig.GetPaveData(0)
  572. if err != nil {
  573. return nil, nil, nil, errors.Trace(err)
  574. }
  575. backendPaveData1, ok := oslPaveData[propagationChannelID]
  576. if !ok {
  577. return nil, nil, nil, errors.TraceNew("unexpected missing OSL file data")
  578. }
  579. // Mock seeding SLOKs
  580. //
  581. // Normally, clients supplying the specified propagation channel ID would
  582. // receive SLOKs via the psiphond tunnel connection
  583. seedState := oslConfig.NewClientSeedState("", propagationChannelID, nil)
  584. seedPortForward := seedState.NewClientSeedPortForward(net.ParseIP("0.0.0.0"), nil)
  585. seedPortForward.UpdateProgress(1, 1, 1)
  586. payload := seedState.GetSeedPayload()
  587. if len(payload.SLOKs) != 1 {
  588. return nil, nil, nil, errors.Tracef("unexpected SLOK count %d", len(payload.SLOKs))
  589. }
  590. clientSLOKs := payload.SLOKs
  591. // Rollover to the next OSL time period and generate a new set of active
  592. // OSLs and SLOKs.
  593. time.Sleep(2 * seedPeriod)
  594. oslPaveData, err = oslConfig.GetPaveData(0)
  595. if err != nil {
  596. return nil, nil, nil, errors.Trace(err)
  597. }
  598. backendPaveData2, ok := oslPaveData[propagationChannelID]
  599. if !ok {
  600. return nil, nil, nil, errors.TraceNew("unexpected missing OSL file data")
  601. }
  602. seedState = oslConfig.NewClientSeedState("", propagationChannelID, nil)
  603. seedPortForward = seedState.NewClientSeedPortForward(net.ParseIP("0.0.0.0"), nil)
  604. seedPortForward.UpdateProgress(1, 1, 1)
  605. payload = seedState.GetSeedPayload()
  606. if len(payload.SLOKs) != 1 {
  607. return nil, nil, nil, errors.Tracef("unexpected SLOK count %d", len(payload.SLOKs))
  608. }
  609. clientSLOKs = append(clientSLOKs, payload.SLOKs...)
  610. // Double check that PaveData periods don't overlap.
  611. for _, paveData1 := range backendPaveData1 {
  612. for _, paveData2 := range backendPaveData2 {
  613. if bytes.Equal(paveData1.FileSpec.ID, paveData2.FileSpec.ID) {
  614. return nil, nil, nil, errors.TraceNew("unexpected pave data overlap")
  615. }
  616. }
  617. }
  618. return backendPaveData1, backendPaveData2, clientSLOKs, nil
  619. }
  // TestDSLTLSConfig bundles a test CA with the backend (server) and relay
// (client) certificates it issued. Each is available both parsed and as
// PEM.
type TestDSLTLSConfig struct {
	// The self-signed root CA.
	CACertificate    *x509.Certificate
	CACertificatePEM []byte

	// The backend's server certificate and private key.
	BackendCertificate    *tls.Certificate
	BackendCertificatePEM []byte
	BackendKeyPEM         []byte

	// The relay's client certificate and private key.
	RelayCertificate    *tls.Certificate
	RelayCertificatePEM []byte
	RelayKeyPEM         []byte
}
  630. func NewTestDSLTLSConfig() (*TestDSLTLSConfig, error) {
  631. CAPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  632. if err != nil {
  633. return nil, errors.Trace(err)
  634. }
  635. now := time.Now()
  636. template := &x509.Certificate{
  637. SerialNumber: big.NewInt(1),
  638. Subject: pkix.Name{
  639. Organization: []string{"test root CA"},
  640. },
  641. NotBefore: now,
  642. NotAfter: now.AddDate(0, 0, 1),
  643. IsCA: true,
  644. BasicConstraintsValid: true,
  645. KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
  646. }
  647. CACertificateDER, err := x509.CreateCertificate(
  648. rand.Reader, template, template, &CAPrivateKey.PublicKey, CAPrivateKey)
  649. if err != nil {
  650. return nil, errors.Trace(err)
  651. }
  652. CACertificatePEM := pem.EncodeToMemory(
  653. &pem.Block{Type: "CERTIFICATE", Bytes: CACertificateDER})
  654. CACertificate, err := x509.ParseCertificate(CACertificateDER)
  655. if err != nil {
  656. return nil, errors.Trace(err)
  657. }
  658. issueCertificate := func(
  659. name string, isServer bool) (
  660. *tls.Certificate, []byte, []byte, error) {
  661. privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  662. if err != nil {
  663. return nil, nil, nil, errors.Trace(err)
  664. }
  665. now := time.Now()
  666. template := &x509.Certificate{
  667. SerialNumber: big.NewInt(time.Now().UnixNano()),
  668. Subject: pkix.Name{
  669. CommonName: name,
  670. },
  671. NotBefore: now,
  672. NotAfter: now.AddDate(0, 0, 1),
  673. KeyUsage: x509.KeyUsageDigitalSignature,
  674. }
  675. if isServer {
  676. template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}
  677. template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
  678. } else {
  679. template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
  680. }
  681. certificateDER, err := x509.CreateCertificate(
  682. rand.Reader, template, CACertificate, &privateKey.PublicKey, CAPrivateKey)
  683. if err != nil {
  684. return nil, nil, nil, errors.Trace(err)
  685. }
  686. certPEM := pem.EncodeToMemory(
  687. &pem.Block{Type: "CERTIFICATE", Bytes: certificateDER})
  688. keyPEM := pem.EncodeToMemory(
  689. &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
  690. tlsCertificate, err := tls.X509KeyPair(certPEM, keyPEM)
  691. if err != nil {
  692. return nil, nil, nil, errors.Trace(err)
  693. }
  694. return &tlsCertificate, certPEM, keyPEM, nil
  695. }
  696. backendCertificate, backendCertificatePEM, backendKeyPEM, err :=
  697. issueCertificate("backend", true)
  698. if err != nil {
  699. return nil, errors.Trace(err)
  700. }
  701. relayCertificate, relayCertificatePEM, relayKeyPEM, err :=
  702. issueCertificate("relay", false)
  703. if err != nil {
  704. return nil, errors.Trace(err)
  705. }
  706. return &TestDSLTLSConfig{
  707. CACertificate: CACertificate,
  708. CACertificatePEM: CACertificatePEM,
  709. BackendCertificate: backendCertificate,
  710. BackendCertificatePEM: backendCertificatePEM,
  711. BackendKeyPEM: backendKeyPEM,
  712. RelayCertificate: relayCertificate,
  713. RelayCertificatePEM: relayCertificatePEM,
  714. RelayKeyPEM: relayKeyPEM,
  715. }, nil
  716. }
  717. func (config *TestDSLTLSConfig) WriteRelayFiles(dirName string) (
  718. string, string, string, error) {
  719. caCertificatesFilename := filepath.Join(
  720. dirName, "dslRelayCACert.pem")
  721. err := os.WriteFile(
  722. caCertificatesFilename,
  723. config.CACertificatePEM,
  724. 0644)
  725. if err != nil {
  726. return "", "", "", errors.Trace(err)
  727. }
  728. hostCertificateFilename := filepath.Join(
  729. dirName, "dslRelayHostCert.pem")
  730. err = os.WriteFile(
  731. hostCertificateFilename,
  732. config.RelayCertificatePEM,
  733. 0644)
  734. if err != nil {
  735. return "", "", "", errors.Trace(err)
  736. }
  737. hostKeyFilename := filepath.Join(
  738. dirName, "dslRelayHostKey.pem")
  739. err = os.WriteFile(
  740. hostKeyFilename,
  741. config.RelayKeyPEM,
  742. 0644)
  743. if err != nil {
  744. return "", "", "", errors.Trace(err)
  745. }
  746. return caCertificatesFilename,
  747. hostCertificateFilename,
  748. hostKeyFilename,
  749. nil
  750. }