trafficRules_test.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
/*
 * Copyright (c) 2022, Psiphon Inc.
 * All rights reserved.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */
  19. package server
  20. import (
  21. "encoding/json"
  22. "io/ioutil"
  23. "os"
  24. "reflect"
  25. "testing"
  26. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  27. )
  28. func TestTrafficRulesFilters(t *testing.T) {
  29. trafficRulesJSON := `
  30. {
  31. "DefaultRules" : {
  32. "RateLimits" : {
  33. "WriteUnthrottledBytes": 1,
  34. "WriteBytesPerSecond": 2,
  35. "ReadUnthrottledBytes": 3,
  36. "ReadBytesPerSecond": 4,
  37. "UnthrottleFirstTunnelOnly": true
  38. },
  39. "AllowTCPPorts" : [5],
  40. "AllowUDPPorts" : [6]
  41. },
  42. "FilteredRules" : [
  43. {
  44. "Filter" : {
  45. "ProviderIDs" : ["H2"]
  46. },
  47. "Rules" : {
  48. "RateLimits" : {
  49. "WriteBytesPerSecond": 99,
  50. "ReadBytesPerSecond": 99
  51. }
  52. }
  53. },
  54. {
  55. "Filter" : {
  56. "ProviderIDs" : ["H1"],
  57. "Regions" : ["R2"],
  58. "HandshakeParameters" : {
  59. "client_version" : ["1"]
  60. }
  61. },
  62. "Rules" : {
  63. "RateLimits" : {
  64. "WriteBytesPerSecond": 7,
  65. "ReadBytesPerSecond": 8
  66. },
  67. "AllowTCPPorts" : [5,9],
  68. "AllowUDPPorts" : [6,10]
  69. }
  70. },
  71. {
  72. "Filter" : {
  73. "TunnelProtocols" : ["P2"],
  74. "Regions" : ["R3", "R4"],
  75. "HandshakeParameters" : {
  76. "client_version" : ["1", "2"]
  77. }
  78. },
  79. "ExceptFilter" : {
  80. "ISPs" : ["I2", "I3"],
  81. "HandshakeParameters" : {
  82. "client_version" : ["1"]
  83. }
  84. },
  85. "Rules" : {
  86. "RateLimits" : {
  87. "WriteBytesPerSecond": 11,
  88. "ReadBytesPerSecond": 12
  89. },
  90. "AllowTCPPorts" : [5,13],
  91. "AllowUDPPorts" : [6,14]
  92. }
  93. },
  94. {
  95. "Filter" : {
  96. "Regions" : ["R3", "R4"],
  97. "HandshakeParameters" : {
  98. "client_version" : ["1", "2"]
  99. }
  100. },
  101. "ExceptFilter" : {
  102. "ISPs" : ["I2", "I3"],
  103. "HandshakeParameters" : {
  104. "client_version" : ["1"]
  105. }
  106. },
  107. "Rules" : {
  108. "RateLimits" : {
  109. "WriteBytesPerSecond": 15,
  110. "ReadBytesPerSecond": 16
  111. },
  112. "AllowTCPPorts" : [5,17],
  113. "AllowUDPPorts" : [6,18]
  114. }
  115. }
  116. ]
  117. }
  118. `
  119. file, err := ioutil.TempFile("", "trafficRules.config")
  120. if err != nil {
  121. t.Fatalf("TempFile create failed: %s", err)
  122. }
  123. _, err = file.Write([]byte(trafficRulesJSON))
  124. if err != nil {
  125. t.Fatalf("TempFile write failed: %s", err)
  126. }
  127. file.Close()
  128. configFileName := file.Name()
  129. defer os.Remove(configFileName)
  130. trafficRules, err := NewTrafficRulesSet(configFileName)
  131. if err != nil {
  132. t.Fatalf("NewTrafficRulesSet failed: %s", err)
  133. }
  134. err = trafficRules.Validate()
  135. if err != nil {
  136. t.Fatalf("TrafficRulesSet.Validate failed: %s", err)
  137. }
  138. makePortList := func(portsJSON string) common.PortList {
  139. var p common.PortList
  140. _ = json.Unmarshal([]byte(portsJSON), &p)
  141. return p
  142. }
  143. // should never get 1st filtered rule with different provider ID
  144. providerID := "H1"
  145. testCases := []struct {
  146. description string
  147. providerID string
  148. isFirstTunnelInSession bool
  149. tunnelProtocol string
  150. geoIPData GeoIPData
  151. state handshakeState
  152. expectedWriteUnthrottledBytes int64
  153. expectedWriteBytesPerSecond int64
  154. expectedReadUnthrottledBytes int64
  155. expectedReadBytesPerSecond int64
  156. expectedAllowTCPPorts common.PortList
  157. expectedAllowUDPPorts common.PortList
  158. }{
  159. {
  160. "get defaults",
  161. providerID,
  162. true,
  163. "P1",
  164. GeoIPData{Country: "R1", ISP: "I1"},
  165. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  166. 1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
  167. },
  168. {
  169. "get defaults for not first tunnel in session",
  170. providerID,
  171. false,
  172. "P1",
  173. GeoIPData{Country: "R1", ISP: "I1"},
  174. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  175. 0, 2, 0, 4, makePortList("[5]"), makePortList("[6]"),
  176. },
  177. {
  178. "get 2nd filtered rule (including provider ID)",
  179. providerID,
  180. true,
  181. "P1",
  182. GeoIPData{Country: "R2", ISP: "I1"},
  183. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  184. 1, 7, 3, 8, makePortList("[5,9]"), makePortList("[6,10]"),
  185. },
  186. {
  187. "don't get 2nd filtered rule with incomplete match",
  188. providerID,
  189. true,
  190. "P1",
  191. GeoIPData{Country: "R2", ISP: "I1"},
  192. handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
  193. 1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
  194. },
  195. {
  196. "get 3rd filtered rule",
  197. providerID,
  198. true,
  199. "P2",
  200. GeoIPData{Country: "R3", ISP: "I1"},
  201. handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
  202. 1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
  203. },
  204. {
  205. "get 3rd filtered rule with incomplete exception",
  206. providerID,
  207. true,
  208. "P2",
  209. GeoIPData{Country: "R3", ISP: "I2"},
  210. handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
  211. 1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
  212. },
  213. {
  214. "don't get 3rd filtered rule due to exception",
  215. providerID,
  216. true,
  217. "P2",
  218. GeoIPData{Country: "R3", ISP: "I2"},
  219. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  220. 1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
  221. },
  222. {
  223. "get 4th filtered rule",
  224. providerID,
  225. true,
  226. "P1",
  227. GeoIPData{Country: "R3", ISP: "I1"},
  228. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  229. 1, 15, 3, 16, makePortList("[5,17]"), makePortList("[6,18]"),
  230. },
  231. {
  232. "don't get 4th filtered rule due to exception",
  233. providerID,
  234. true,
  235. "P1",
  236. GeoIPData{Country: "R3", ISP: "I2"},
  237. handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
  238. 1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
  239. },
  240. }
  241. for _, testCase := range testCases {
  242. t.Run(testCase.description, func(t *testing.T) {
  243. rules := trafficRules.GetTrafficRules(
  244. testCase.providerID,
  245. testCase.isFirstTunnelInSession,
  246. testCase.tunnelProtocol,
  247. testCase.geoIPData,
  248. testCase.state)
  249. if *rules.RateLimits.WriteUnthrottledBytes != testCase.expectedWriteUnthrottledBytes {
  250. t.Errorf("unexpected rules.RateLimits.WriteUnthrottledBytes: %v != %v",
  251. *rules.RateLimits.WriteUnthrottledBytes, testCase.expectedWriteUnthrottledBytes)
  252. }
  253. if *rules.RateLimits.WriteBytesPerSecond != testCase.expectedWriteBytesPerSecond {
  254. t.Errorf("unexpected rules.RateLimits.WriteBytesPerSecond: %v != %v",
  255. *rules.RateLimits.WriteBytesPerSecond, testCase.expectedWriteBytesPerSecond)
  256. }
  257. if *rules.RateLimits.ReadUnthrottledBytes != testCase.expectedReadUnthrottledBytes {
  258. t.Errorf("unexpected rules.RateLimits.ReadUnthrottledBytes: %v != %v",
  259. *rules.RateLimits.ReadUnthrottledBytes, testCase.expectedReadUnthrottledBytes)
  260. }
  261. if *rules.RateLimits.ReadBytesPerSecond != testCase.expectedReadBytesPerSecond {
  262. t.Errorf("unexpected rules.RateLimits.ReadBytesPerSecond: %v != %v",
  263. *rules.RateLimits.ReadBytesPerSecond, testCase.expectedReadBytesPerSecond)
  264. }
  265. if !reflect.DeepEqual(*rules.AllowTCPPorts, testCase.expectedAllowTCPPorts) {
  266. t.Errorf("unexpected rules.RateLimits.AllowTCPPorts: %v != %v",
  267. *rules.AllowTCPPorts, testCase.expectedAllowTCPPorts)
  268. }
  269. if !reflect.DeepEqual(*rules.AllowUDPPorts, testCase.expectedAllowUDPPorts) {
  270. t.Errorf("unexpected rules.RateLimits.AllowUDPPorts: %v != %v",
  271. *rules.AllowUDPPorts, testCase.expectedAllowUDPPorts)
  272. }
  273. })
  274. }
  275. }