mux_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "sync"
  10. "testing"
  11. )
  12. // [Psiphon]
  13. // See comment in channel.go
  14. var testChannelWindowSize = getChannelWindowSize("")
  15. func muxPair() (*mux, *mux) {
  16. a, b := memPipe()
  17. s := newMux(a)
  18. c := newMux(b)
  19. return s, c
  20. }
  21. // Returns both ends of a channel, and the mux for the 2nd
  22. // channel.
  23. func channelPair(t *testing.T) (*channel, *channel, *mux) {
  24. c, s := muxPair()
  25. res := make(chan *channel, 1)
  26. go func() {
  27. newCh, ok := <-s.incomingChannels
  28. if !ok {
  29. t.Error("no incoming channel")
  30. close(res)
  31. return
  32. }
  33. if newCh.ChannelType() != "chan" {
  34. t.Errorf("got type %q want chan", newCh.ChannelType())
  35. newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
  36. close(res)
  37. return
  38. }
  39. ch, _, err := newCh.Accept()
  40. if err != nil {
  41. t.Errorf("accept: %v", err)
  42. close(res)
  43. return
  44. }
  45. res <- ch.(*channel)
  46. }()
  47. ch, err := c.openChannel("chan", nil)
  48. if err != nil {
  49. t.Fatalf("OpenChannel: %v", err)
  50. }
  51. w := <-res
  52. if w == nil {
  53. t.Fatal("unable to get write channel")
  54. }
  55. return w, ch, c
  56. }
  57. // Test that stderr and stdout can be addressed from different
  58. // goroutines. This is intended for use with the race detector.
  59. func TestMuxChannelExtendedThreadSafety(t *testing.T) {
  60. writer, reader, mux := channelPair(t)
  61. defer writer.Close()
  62. defer reader.Close()
  63. defer mux.Close()
  64. var wr, rd sync.WaitGroup
  65. magic := "hello world"
  66. wr.Add(2)
  67. go func() {
  68. io.WriteString(writer, magic)
  69. wr.Done()
  70. }()
  71. go func() {
  72. io.WriteString(writer.Stderr(), magic)
  73. wr.Done()
  74. }()
  75. rd.Add(2)
  76. go func() {
  77. c, err := io.ReadAll(reader)
  78. if string(c) != magic {
  79. t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
  80. }
  81. rd.Done()
  82. }()
  83. go func() {
  84. c, err := io.ReadAll(reader.Stderr())
  85. if string(c) != magic {
  86. t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
  87. }
  88. rd.Done()
  89. }()
  90. wr.Wait()
  91. writer.CloseWrite()
  92. rd.Wait()
  93. }
  94. func TestMuxReadWrite(t *testing.T) {
  95. s, c, mux := channelPair(t)
  96. defer s.Close()
  97. defer c.Close()
  98. defer mux.Close()
  99. magic := "hello world"
  100. magicExt := "hello stderr"
  101. var wg sync.WaitGroup
  102. t.Cleanup(wg.Wait)
  103. wg.Add(1)
  104. go func() {
  105. defer wg.Done()
  106. _, err := s.Write([]byte(magic))
  107. if err != nil {
  108. t.Errorf("Write: %v", err)
  109. return
  110. }
  111. _, err = s.Extended(1).Write([]byte(magicExt))
  112. if err != nil {
  113. t.Errorf("Write: %v", err)
  114. return
  115. }
  116. }()
  117. var buf [1024]byte
  118. n, err := c.Read(buf[:])
  119. if err != nil {
  120. t.Fatalf("server Read: %v", err)
  121. }
  122. got := string(buf[:n])
  123. if got != magic {
  124. t.Fatalf("server: got %q want %q", got, magic)
  125. }
  126. n, err = c.Extended(1).Read(buf[:])
  127. if err != nil {
  128. t.Fatalf("server Read: %v", err)
  129. }
  130. got = string(buf[:n])
  131. if got != magicExt {
  132. t.Fatalf("server: got %q want %q", got, magic)
  133. }
  134. }
  135. func TestMuxChannelOverflow(t *testing.T) {
  136. reader, writer, mux := channelPair(t)
  137. defer reader.Close()
  138. defer writer.Close()
  139. defer mux.Close()
  140. var wg sync.WaitGroup
  141. t.Cleanup(wg.Wait)
  142. wg.Add(1)
  143. go func() {
  144. defer wg.Done()
  145. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  146. t.Errorf("could not fill window: %v", err)
  147. }
  148. writer.Write(make([]byte, 1))
  149. }()
  150. writer.remoteWin.waitWriterBlocked()
  151. // Send 1 byte.
  152. packet := make([]byte, 1+4+4+1)
  153. packet[0] = msgChannelData
  154. marshalUint32(packet[1:], writer.remoteId)
  155. marshalUint32(packet[5:], uint32(1))
  156. packet[9] = 42
  157. if err := writer.mux.conn.writePacket(packet); err != nil {
  158. t.Errorf("could not send packet")
  159. }
  160. if _, err := reader.SendRequest("hello", true, nil); err == nil {
  161. t.Errorf("SendRequest succeeded.")
  162. }
  163. }
  164. func TestMuxChannelReadUnblock(t *testing.T) {
  165. reader, writer, mux := channelPair(t)
  166. defer reader.Close()
  167. defer writer.Close()
  168. defer mux.Close()
  169. var wg sync.WaitGroup
  170. t.Cleanup(wg.Wait)
  171. wg.Add(1)
  172. go func() {
  173. defer wg.Done()
  174. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  175. t.Errorf("could not fill window: %v", err)
  176. }
  177. if _, err := writer.Write(make([]byte, 1)); err != nil {
  178. t.Errorf("Write: %v", err)
  179. }
  180. writer.Close()
  181. }()
  182. writer.remoteWin.waitWriterBlocked()
  183. buf := make([]byte, 32768)
  184. for {
  185. _, err := reader.Read(buf)
  186. if err == io.EOF {
  187. break
  188. }
  189. if err != nil {
  190. t.Fatalf("Read: %v", err)
  191. }
  192. }
  193. }
  194. func TestMuxChannelCloseWriteUnblock(t *testing.T) {
  195. reader, writer, mux := channelPair(t)
  196. defer reader.Close()
  197. defer writer.Close()
  198. defer mux.Close()
  199. var wg sync.WaitGroup
  200. t.Cleanup(wg.Wait)
  201. wg.Add(1)
  202. go func() {
  203. defer wg.Done()
  204. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  205. t.Errorf("could not fill window: %v", err)
  206. }
  207. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  208. t.Errorf("got %v, want EOF for unblock write", err)
  209. }
  210. }()
  211. writer.remoteWin.waitWriterBlocked()
  212. reader.Close()
  213. }
  214. func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
  215. reader, writer, mux := channelPair(t)
  216. defer reader.Close()
  217. defer writer.Close()
  218. defer mux.Close()
  219. var wg sync.WaitGroup
  220. t.Cleanup(wg.Wait)
  221. wg.Add(1)
  222. go func() {
  223. defer wg.Done()
  224. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  225. t.Errorf("could not fill window: %v", err)
  226. }
  227. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  228. t.Errorf("got %v, want EOF for unblock write", err)
  229. }
  230. }()
  231. writer.remoteWin.waitWriterBlocked()
  232. mux.Close()
  233. }
  234. func TestMuxReject(t *testing.T) {
  235. client, server := muxPair()
  236. defer server.Close()
  237. defer client.Close()
  238. var wg sync.WaitGroup
  239. t.Cleanup(wg.Wait)
  240. wg.Add(1)
  241. go func() {
  242. defer wg.Done()
  243. ch, ok := <-server.incomingChannels
  244. if !ok {
  245. t.Error("cannot accept channel")
  246. return
  247. }
  248. if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
  249. t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
  250. ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
  251. return
  252. }
  253. ch.Reject(RejectionReason(42), "message")
  254. }()
  255. ch, err := client.openChannel("ch", []byte("extra"))
  256. if ch != nil {
  257. t.Fatal("openChannel not rejected")
  258. }
  259. ocf, ok := err.(*OpenChannelError)
  260. if !ok {
  261. t.Errorf("got %#v want *OpenChannelError", err)
  262. } else if ocf.Reason != 42 || ocf.Message != "message" {
  263. t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
  264. }
  265. want := "ssh: rejected: unknown reason 42 (message)"
  266. if err.Error() != want {
  267. t.Errorf("got %q, want %q", err.Error(), want)
  268. }
  269. }
  270. func TestMuxChannelRequest(t *testing.T) {
  271. client, server, mux := channelPair(t)
  272. defer server.Close()
  273. defer client.Close()
  274. defer mux.Close()
  275. var received int
  276. var wg sync.WaitGroup
  277. t.Cleanup(wg.Wait)
  278. wg.Add(1)
  279. go func() {
  280. for r := range server.incomingRequests {
  281. received++
  282. r.Reply(r.Type == "yes", nil)
  283. }
  284. wg.Done()
  285. }()
  286. _, err := client.SendRequest("yes", false, nil)
  287. if err != nil {
  288. t.Fatalf("SendRequest: %v", err)
  289. }
  290. ok, err := client.SendRequest("yes", true, nil)
  291. if err != nil {
  292. t.Fatalf("SendRequest: %v", err)
  293. }
  294. if !ok {
  295. t.Errorf("SendRequest(yes): %v", ok)
  296. }
  297. ok, err = client.SendRequest("no", true, nil)
  298. if err != nil {
  299. t.Fatalf("SendRequest: %v", err)
  300. }
  301. if ok {
  302. t.Errorf("SendRequest(no): %v", ok)
  303. }
  304. client.Close()
  305. wg.Wait()
  306. if received != 3 {
  307. t.Errorf("got %d requests, want %d", received, 3)
  308. }
  309. }
  310. func TestMuxUnknownChannelRequests(t *testing.T) {
  311. clientPipe, serverPipe := memPipe()
  312. client := newMux(clientPipe)
  313. defer serverPipe.Close()
  314. defer client.Close()
  315. kDone := make(chan error, 1)
  316. go func() {
  317. // Ignore unknown channel messages that don't want a reply.
  318. err := serverPipe.writePacket(Marshal(channelRequestMsg{
  319. PeersID: 1,
  320. Request: "keepalive@openssh.com",
  321. WantReply: false,
  322. RequestSpecificData: []byte{},
  323. }))
  324. if err != nil {
  325. kDone <- fmt.Errorf("send: %w", err)
  326. return
  327. }
  328. // Send a keepalive, which should get a channel failure message
  329. // in response.
  330. err = serverPipe.writePacket(Marshal(channelRequestMsg{
  331. PeersID: 2,
  332. Request: "keepalive@openssh.com",
  333. WantReply: true,
  334. RequestSpecificData: []byte{},
  335. }))
  336. if err != nil {
  337. kDone <- fmt.Errorf("send: %w", err)
  338. return
  339. }
  340. packet, err := serverPipe.readPacket()
  341. if err != nil {
  342. kDone <- fmt.Errorf("read packet: %w", err)
  343. return
  344. }
  345. decoded, err := decode(packet)
  346. if err != nil {
  347. kDone <- fmt.Errorf("decode failed: %w", err)
  348. return
  349. }
  350. switch msg := decoded.(type) {
  351. case *channelRequestFailureMsg:
  352. if msg.PeersID != 2 {
  353. kDone <- fmt.Errorf("received response to wrong message: %v", msg)
  354. return
  355. }
  356. default:
  357. kDone <- fmt.Errorf("unexpected channel message: %v", msg)
  358. return
  359. }
  360. kDone <- nil
  361. // Receive and respond to the keepalive to confirm the mux is
  362. // still processing requests.
  363. packet, err = serverPipe.readPacket()
  364. if err != nil {
  365. kDone <- fmt.Errorf("read packet: %w", err)
  366. return
  367. }
  368. if packet[0] != msgGlobalRequest {
  369. kDone <- errors.New("expected global request")
  370. return
  371. }
  372. err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
  373. Data: []byte{},
  374. }))
  375. if err != nil {
  376. kDone <- fmt.Errorf("failed to send failure msg: %w", err)
  377. return
  378. }
  379. close(kDone)
  380. }()
  381. // Wait for the server to send the keepalive message and receive back a
  382. // response.
  383. if err := <-kDone; err != nil {
  384. t.Fatal(err)
  385. }
  386. // Confirm client hasn't closed.
  387. if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
  388. t.Fatalf("failed to send keepalive: %v", err)
  389. }
  390. // Wait for the server to shut down.
  391. if err := <-kDone; err != nil {
  392. t.Fatal(err)
  393. }
  394. }
  395. func TestMuxClosedChannel(t *testing.T) {
  396. clientPipe, serverPipe := memPipe()
  397. client := newMux(clientPipe)
  398. defer serverPipe.Close()
  399. defer client.Close()
  400. kDone := make(chan error, 1)
  401. go func() {
  402. // Open the channel.
  403. packet, err := serverPipe.readPacket()
  404. if err != nil {
  405. kDone <- fmt.Errorf("read packet: %w", err)
  406. return
  407. }
  408. if packet[0] != msgChannelOpen {
  409. kDone <- errors.New("expected chan open")
  410. return
  411. }
  412. var openMsg channelOpenMsg
  413. if err := Unmarshal(packet, &openMsg); err != nil {
  414. kDone <- fmt.Errorf("unmarshal: %w", err)
  415. return
  416. }
  417. // Send back the opened channel confirmation.
  418. err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{
  419. PeersID: openMsg.PeersID,
  420. MyID: 0,
  421. MyWindow: 0,
  422. MaxPacketSize: channelMaxPacket,
  423. }))
  424. if err != nil {
  425. kDone <- fmt.Errorf("send: %w", err)
  426. return
  427. }
  428. // Close the channel.
  429. err = serverPipe.writePacket(Marshal(channelCloseMsg{
  430. PeersID: openMsg.PeersID,
  431. }))
  432. if err != nil {
  433. kDone <- fmt.Errorf("send: %w", err)
  434. return
  435. }
  436. // Send a keepalive message on the channel we just closed.
  437. err = serverPipe.writePacket(Marshal(channelRequestMsg{
  438. PeersID: openMsg.PeersID,
  439. Request: "keepalive@openssh.com",
  440. WantReply: true,
  441. RequestSpecificData: []byte{},
  442. }))
  443. if err != nil {
  444. kDone <- fmt.Errorf("send: %w", err)
  445. return
  446. }
  447. // Receive the channel closed response.
  448. packet, err = serverPipe.readPacket()
  449. if err != nil {
  450. kDone <- fmt.Errorf("read packet: %w", err)
  451. return
  452. }
  453. if packet[0] != msgChannelClose {
  454. kDone <- errors.New("expected channel close")
  455. return
  456. }
  457. // Receive the keepalive response failure.
  458. packet, err = serverPipe.readPacket()
  459. if err != nil {
  460. kDone <- fmt.Errorf("read packet: %w", err)
  461. return
  462. }
  463. if packet[0] != msgChannelFailure {
  464. kDone <- errors.New("expected channel failure")
  465. return
  466. }
  467. kDone <- nil
  468. // Receive and respond to the keepalive to confirm the mux is
  469. // still processing requests.
  470. packet, err = serverPipe.readPacket()
  471. if err != nil {
  472. kDone <- fmt.Errorf("read packet: %w", err)
  473. return
  474. }
  475. if packet[0] != msgGlobalRequest {
  476. kDone <- errors.New("expected global request")
  477. return
  478. }
  479. err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
  480. Data: []byte{},
  481. }))
  482. if err != nil {
  483. kDone <- fmt.Errorf("failed to send failure msg: %w", err)
  484. return
  485. }
  486. close(kDone)
  487. }()
  488. // Open a channel.
  489. ch, err := client.openChannel("chan", nil)
  490. if err != nil {
  491. t.Fatalf("OpenChannel: %v", err)
  492. }
  493. defer ch.Close()
  494. // Wait for the server to close the channel and send the keepalive.
  495. <-kDone
  496. // Make sure the channel closed.
  497. if _, ok := <-ch.incomingRequests; ok {
  498. t.Fatalf("channel not closed")
  499. }
  500. // Confirm client hasn't closed
  501. if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
  502. t.Fatalf("failed to send keepalive: %v", err)
  503. }
  504. // Wait for the server to shut down.
  505. <-kDone
  506. }
  507. func TestMuxGlobalRequest(t *testing.T) {
  508. var sawPeek bool
  509. var wg sync.WaitGroup
  510. defer func() {
  511. wg.Wait()
  512. if !sawPeek {
  513. t.Errorf("never saw 'peek' request")
  514. }
  515. }()
  516. clientMux, serverMux := muxPair()
  517. defer serverMux.Close()
  518. defer clientMux.Close()
  519. wg.Add(1)
  520. go func() {
  521. defer wg.Done()
  522. for r := range serverMux.incomingRequests {
  523. sawPeek = sawPeek || r.Type == "peek"
  524. if r.WantReply {
  525. err := r.Reply(r.Type == "yes",
  526. append([]byte(r.Type), r.Payload...))
  527. if err != nil {
  528. t.Errorf("AckRequest: %v", err)
  529. }
  530. }
  531. }
  532. }()
  533. _, _, err := clientMux.SendRequest("peek", false, nil)
  534. if err != nil {
  535. t.Errorf("SendRequest: %v", err)
  536. }
  537. ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
  538. if !ok || string(data) != "yesa" || err != nil {
  539. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  540. ok, data, err)
  541. }
  542. if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
  543. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  544. ok, data, err)
  545. }
  546. if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
  547. t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
  548. ok, data, err)
  549. }
  550. }
  551. func TestMuxGlobalRequestUnblock(t *testing.T) {
  552. clientMux, serverMux := muxPair()
  553. defer serverMux.Close()
  554. defer clientMux.Close()
  555. result := make(chan error, 1)
  556. go func() {
  557. _, _, err := clientMux.SendRequest("hello", true, nil)
  558. result <- err
  559. }()
  560. <-serverMux.incomingRequests
  561. serverMux.conn.Close()
  562. err := <-result
  563. if err != io.EOF {
  564. t.Errorf("want EOF, got %v", io.EOF)
  565. }
  566. }
  567. func TestMuxChannelRequestUnblock(t *testing.T) {
  568. a, b, connB := channelPair(t)
  569. defer a.Close()
  570. defer b.Close()
  571. defer connB.Close()
  572. result := make(chan error, 1)
  573. go func() {
  574. _, err := a.SendRequest("hello", true, nil)
  575. result <- err
  576. }()
  577. <-b.incomingRequests
  578. connB.conn.Close()
  579. err := <-result
  580. if err != io.EOF {
  581. t.Errorf("want EOF, got %v", err)
  582. }
  583. }
  584. func TestMuxCloseChannel(t *testing.T) {
  585. r, w, mux := channelPair(t)
  586. defer mux.Close()
  587. defer r.Close()
  588. defer w.Close()
  589. result := make(chan error, 1)
  590. go func() {
  591. var b [1024]byte
  592. _, err := r.Read(b[:])
  593. result <- err
  594. }()
  595. if err := w.Close(); err != nil {
  596. t.Errorf("w.Close: %v", err)
  597. }
  598. if _, err := w.Write([]byte("hello")); err != io.EOF {
  599. t.Errorf("got err %v, want io.EOF after Close", err)
  600. }
  601. if err := <-result; err != io.EOF {
  602. t.Errorf("got %v (%T), want io.EOF", err, err)
  603. }
  604. }
  605. func TestMuxCloseWriteChannel(t *testing.T) {
  606. r, w, mux := channelPair(t)
  607. defer mux.Close()
  608. result := make(chan error, 1)
  609. go func() {
  610. var b [1024]byte
  611. _, err := r.Read(b[:])
  612. result <- err
  613. }()
  614. if err := w.CloseWrite(); err != nil {
  615. t.Errorf("w.CloseWrite: %v", err)
  616. }
  617. if _, err := w.Write([]byte("hello")); err != io.EOF {
  618. t.Errorf("got err %v, want io.EOF after CloseWrite", err)
  619. }
  620. if err := <-result; err != io.EOF {
  621. t.Errorf("got %v (%T), want io.EOF", err, err)
  622. }
  623. }
  624. func TestMuxInvalidRecord(t *testing.T) {
  625. a, b := muxPair()
  626. defer a.Close()
  627. defer b.Close()
  628. packet := make([]byte, 1+4+4+1)
  629. packet[0] = msgChannelData
  630. marshalUint32(packet[1:], 29348723 /* invalid channel id */)
  631. marshalUint32(packet[5:], 1)
  632. packet[9] = 42
  633. a.conn.writePacket(packet)
  634. go a.SendRequest("hello", false, nil)
  635. // 'a' wrote an invalid packet, so 'b' has exited.
  636. req, ok := <-b.incomingRequests
  637. if ok {
  638. t.Errorf("got request %#v after receiving invalid packet", req)
  639. }
  640. }
  641. func TestZeroWindowAdjust(t *testing.T) {
  642. a, b, mux := channelPair(t)
  643. defer a.Close()
  644. defer b.Close()
  645. defer mux.Close()
  646. go func() {
  647. io.WriteString(a, "hello")
  648. // bogus adjust.
  649. a.sendMessage(windowAdjustMsg{})
  650. io.WriteString(a, "world")
  651. a.Close()
  652. }()
  653. want := "helloworld"
  654. c, _ := io.ReadAll(b)
  655. if string(c) != want {
  656. t.Errorf("got %q want %q", c, want)
  657. }
  658. }
  659. func TestMuxMaxPacketSize(t *testing.T) {
  660. a, b, mux := channelPair(t)
  661. defer a.Close()
  662. defer b.Close()
  663. defer mux.Close()
  664. large := make([]byte, a.maxRemotePayload+1)
  665. packet := make([]byte, 1+4+4+1+len(large))
  666. packet[0] = msgChannelData
  667. marshalUint32(packet[1:], a.remoteId)
  668. marshalUint32(packet[5:], uint32(len(large)))
  669. packet[9] = 42
  670. if err := a.mux.conn.writePacket(packet); err != nil {
  671. t.Errorf("could not send packet")
  672. }
  673. var wg sync.WaitGroup
  674. t.Cleanup(wg.Wait)
  675. wg.Add(1)
  676. go func() {
  677. a.SendRequest("hello", false, nil)
  678. wg.Done()
  679. }()
  680. _, ok := <-b.incomingRequests
  681. if ok {
  682. t.Errorf("connection still alive after receiving large packet.")
  683. }
  684. }
  685. func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
  686. s, c, mux := channelPair(t)
  687. cTransport := mux.conn.(*memTransport)
  688. defer s.Close()
  689. defer c.Close()
  690. defer mux.Close()
  691. var wg sync.WaitGroup
  692. t.Cleanup(wg.Wait)
  693. data := make([]byte, 1024)
  694. wg.Add(1)
  695. go func() {
  696. defer wg.Done()
  697. _, err := s.Write(data)
  698. if err != nil {
  699. t.Errorf("Write: %v", err)
  700. return
  701. }
  702. }()
  703. cWritesInit := cTransport.getWriteCount()
  704. buf := make([]byte, 1)
  705. for i := 0; i < len(data); i++ {
  706. n, err := c.Read(buf)
  707. if n != len(buf) || err != nil {
  708. t.Fatalf("Read: %v, %v", n, err)
  709. }
  710. }
  711. cWrites := cTransport.getWriteCount() - cWritesInit
  712. // reading 1 KiB should not cause any window updates to be sent, but allow
  713. // for some unexpected writes
  714. if cWrites > 30 {
  715. t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
  716. }
  717. }
  718. // Don't ship code with debug=true.
  719. func TestDebug(t *testing.T) {
  720. if debugMux {
  721. t.Error("mux debug switched on")
  722. }
  723. if debugHandshake {
  724. t.Error("handshake debug switched on")
  725. }
  726. if debugTransport {
  727. t.Error("transport debug switched on")
  728. }
  729. }