shim.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /*
  2. * Copyright (c) 2025, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package dsl
  20. import (
  21. "unsafe"
  22. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  23. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  24. "github.com/fxamacker/cbor/v2"
  25. )
  26. // NewBackendTestShim returns a shim that implements
  27. // psiphon/common/internal/testutils.DSLBackendTestShim. This shim, intended
  28. // only for testing, avoids the import cycle that would result if the shared
  29. // test DSL backend imported common/dsl directly.
  30. func NewBackendTestShim() *backendTestShim {
  31. return &backendTestShim{}
  32. }
  // backendTestShim is a field-less adapter. Its methods translate between
  // this package's request/response types and plain types the shared test
  // DSL backend can name. Having no state, its zero value is ready to use.
  type backendTestShim struct{}
  35. func (b *backendTestShim) ClientIPHeaderName() string {
  36. return PsiphonClientIPHeader
  37. }
  38. func (b *backendTestShim) ClientGeoIPDataHeaderName() string {
  39. return PsiphonClientGeoIPDataHeader
  40. }
  41. func (b *backendTestShim) ClientTunneledHeaderName() string {
  42. return PsiphonClientTunneledHeader
  43. }
  44. func (b *backendTestShim) HostIDHeaderName() string {
  45. return PsiphonHostIDHeader
  46. }
  47. func (b *backendTestShim) DiscoverServerEntriesRequestPath() string {
  48. return RequestPathDiscoverServerEntries
  49. }
  50. func (b *backendTestShim) GetServerEntriesRequestPath() string {
  51. return RequestPathGetServerEntries
  52. }
  53. func (b *backendTestShim) GetActiveOSLsRequestPath() string {
  54. return RequestPathGetActiveOSLs
  55. }
  56. func (b *backendTestShim) GetOSLFileSpecsRequestPath() string {
  57. return RequestPathGetOSLFileSpecs
  58. }
  59. func (b *backendTestShim) UnmarshalDiscoverServerEntriesRequest(
  60. cborRequest []byte) (
  61. apiParams protocol.PackedAPIParameters,
  62. oslKeys [][]byte,
  63. discoverCount int32,
  64. retErr error) {
  65. var request *DiscoverServerEntriesRequest
  66. err := cbor.Unmarshal(cborRequest, &request)
  67. if err != nil {
  68. return nil, nil, 0, errors.Trace(err)
  69. }
  70. return request.BaseAPIParameters,
  71. convertSlice[OSLKey, []byte](request.OSLKeys),
  72. request.DiscoverCount,
  73. nil
  74. }
  75. func (b *backendTestShim) MarshalDiscoverServerEntriesResponse(
  76. versionedServerEntryTags []*struct {
  77. Tag []byte
  78. Version int32
  79. PrioritizeDial bool
  80. }) (
  81. cborResponse []byte,
  82. retErr error) {
  83. response := &DiscoverServerEntriesResponse{
  84. VersionedServerEntryTags: convertSlice[
  85. *struct {
  86. Tag []byte
  87. Version int32
  88. PrioritizeDial bool
  89. }, *VersionedServerEntryTag](versionedServerEntryTags),
  90. }
  91. cborResponse, err := protocol.CBOREncoding.Marshal(response)
  92. return cborResponse, errors.Trace(err)
  93. }
  94. func (b *backendTestShim) UnmarshalGetServerEntriesRequest(
  95. cborRequest []byte) (
  96. apiParams protocol.PackedAPIParameters,
  97. serverEntryTags [][]byte,
  98. retErr error) {
  99. var request *GetServerEntriesRequest
  100. err := cbor.Unmarshal(cborRequest, &request)
  101. if err != nil {
  102. return nil, nil, errors.Trace(err)
  103. }
  104. return request.BaseAPIParameters,
  105. convertSlice[ServerEntryTag, []byte](request.ServerEntryTags),
  106. nil
  107. }
  108. func (b *backendTestShim) MarshalGetServerEntriesResponse(
  109. sourcedServerEntries []*struct {
  110. ServerEntryFields protocol.PackedServerEntryFields
  111. Source string
  112. }) (
  113. cborResponse []byte,
  114. retErr error) {
  115. response := &GetServerEntriesResponse{
  116. SourcedServerEntries: convertSlice[
  117. *struct {
  118. ServerEntryFields protocol.PackedServerEntryFields
  119. Source string
  120. }, *SourcedServerEntry](sourcedServerEntries),
  121. }
  122. cborResponse, err := protocol.CBOREncoding.Marshal(response)
  123. return cborResponse, errors.Trace(err)
  124. }
  125. func (b *backendTestShim) UnmarshalGetActiveOSLsRequest(
  126. cborRequest []byte) (
  127. apiParams protocol.PackedAPIParameters,
  128. retErr error) {
  129. var request *GetActiveOSLsRequest
  130. err := cbor.Unmarshal(cborRequest, &request)
  131. if err != nil {
  132. return nil, errors.Trace(err)
  133. }
  134. return request.BaseAPIParameters, nil
  135. }
  136. func (b *backendTestShim) MarshalGetActiveOSLsResponse(
  137. activeOSLIDs [][]byte) (
  138. cborResponse []byte,
  139. retErr error) {
  140. response := &GetActiveOSLsResponse{
  141. ActiveOSLIDs: convertSlice[[]byte, OSLID](activeOSLIDs),
  142. }
  143. cborResponse, err := protocol.CBOREncoding.Marshal(response)
  144. return cborResponse, errors.Trace(err)
  145. }
  146. func (b *backendTestShim) UnmarshalGetOSLFileSpecsRequest(
  147. cborRequest []byte) (
  148. apiParams protocol.PackedAPIParameters,
  149. oslIDs [][]byte,
  150. retErr error) {
  151. var request *GetOSLFileSpecsRequest
  152. err := cbor.Unmarshal(cborRequest, &request)
  153. if err != nil {
  154. return nil, nil, errors.Trace(err)
  155. }
  156. return request.BaseAPIParameters,
  157. convertSlice[OSLID, []byte](request.OSLIDs),
  158. nil
  159. }
  160. func (b *backendTestShim) MarshalGetOSLFileSpecsResponse(
  161. oslFileSpecs [][]byte) (
  162. cborResponse []byte,
  163. retErr error) {
  164. response := &GetOSLFileSpecsResponse{
  165. OSLFileSpecs: convertSlice[[]byte, OSLFileSpec](oslFileSpecs),
  166. }
  167. cborResponse, err := protocol.CBOREncoding.Marshal(response)
  168. return cborResponse, errors.Trace(err)
  169. }
  // convertSlice reinterprets s, a []A, as a []B without copying. The result
  // shares s's backing array, length and capacity, so a write through either
  // slice is visible through the other.
  //
  // This is sound only when A and B have identical memory layouts, for
  // example a named byte-slice type and []byte, or pointers to structs with
  // identical field sequences. The runtime guard checks two things: that the
  // element sizes match, and that B needs no stricter alignment than A. For
  // pointer element types it cannot check the pointed-to layouts, so callers
  // must guarantee those. An incompatible A/B pair is a programming error and
  // panics, whatever the length of s.
  //
  // A nil or empty s yields a non-nil, empty []B. NOTE(review): presumably
  // this makes the CBOR encoder emit an empty array rather than null; confirm.
  func convertSlice[A any, B any](s []A) []B {
  	var a A
  	var b B
  	// Run the guard before the empty fast path, so misuse is caught even
  	// when a caller happens to pass an empty slice.
  	if unsafe.Sizeof(a) != unsafe.Sizeof(b) ||
  		unsafe.Alignof(b) > unsafe.Alignof(a) {
  		panic("incompatible types")
  	}
  	if len(s) == 0 {
  		return []B{}
  	}
  	// A slice header's layout (data pointer, length, capacity) does not depend
  	// on the element type. Reinterpreting the header therefore relabels the
  	// existing elements in place.
  	return *(*[]B)(unsafe.Pointer(&s))
  }