mux_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "io/ioutil"
  8. "sync"
  9. "testing"
  10. "time"
  11. )
  12. // PSIPHON
  13. // =======
  14. // See comment in channel.go
  15. var testChannelWindowSize = getChannelWindowSize("")
  16. func muxPair() (*mux, *mux) {
  17. a, b := memPipe()
  18. s := newMux(a)
  19. c := newMux(b)
  20. return s, c
  21. }
  22. // Returns both ends of a channel, and the mux for the 2nd
  23. // channel.
  24. func channelPair(t *testing.T) (*channel, *channel, *mux) {
  25. c, s := muxPair()
  26. res := make(chan *channel, 1)
  27. go func() {
  28. newCh, ok := <-s.incomingChannels
  29. if !ok {
  30. t.Fatalf("No incoming channel")
  31. }
  32. if newCh.ChannelType() != "chan" {
  33. t.Fatalf("got type %q want chan", newCh.ChannelType())
  34. }
  35. ch, _, err := newCh.Accept()
  36. if err != nil {
  37. t.Fatalf("Accept %v", err)
  38. }
  39. res <- ch.(*channel)
  40. }()
  41. ch, err := c.openChannel("chan", nil)
  42. if err != nil {
  43. t.Fatalf("OpenChannel: %v", err)
  44. }
  45. return <-res, ch, c
  46. }
  47. // Test that stderr and stdout can be addressed from different
  48. // goroutines. This is intended for use with the race detector.
  49. func TestMuxChannelExtendedThreadSafety(t *testing.T) {
  50. writer, reader, mux := channelPair(t)
  51. defer writer.Close()
  52. defer reader.Close()
  53. defer mux.Close()
  54. var wr, rd sync.WaitGroup
  55. magic := "hello world"
  56. wr.Add(2)
  57. go func() {
  58. io.WriteString(writer, magic)
  59. wr.Done()
  60. }()
  61. go func() {
  62. io.WriteString(writer.Stderr(), magic)
  63. wr.Done()
  64. }()
  65. rd.Add(2)
  66. go func() {
  67. c, err := ioutil.ReadAll(reader)
  68. if string(c) != magic {
  69. t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
  70. }
  71. rd.Done()
  72. }()
  73. go func() {
  74. c, err := ioutil.ReadAll(reader.Stderr())
  75. if string(c) != magic {
  76. t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
  77. }
  78. rd.Done()
  79. }()
  80. wr.Wait()
  81. writer.CloseWrite()
  82. rd.Wait()
  83. }
  84. func TestMuxReadWrite(t *testing.T) {
  85. s, c, mux := channelPair(t)
  86. defer s.Close()
  87. defer c.Close()
  88. defer mux.Close()
  89. magic := "hello world"
  90. magicExt := "hello stderr"
  91. go func() {
  92. _, err := s.Write([]byte(magic))
  93. if err != nil {
  94. t.Fatalf("Write: %v", err)
  95. }
  96. _, err = s.Extended(1).Write([]byte(magicExt))
  97. if err != nil {
  98. t.Fatalf("Write: %v", err)
  99. }
  100. }()
  101. var buf [1024]byte
  102. n, err := c.Read(buf[:])
  103. if err != nil {
  104. t.Fatalf("server Read: %v", err)
  105. }
  106. got := string(buf[:n])
  107. if got != magic {
  108. t.Fatalf("server: got %q want %q", got, magic)
  109. }
  110. n, err = c.Extended(1).Read(buf[:])
  111. if err != nil {
  112. t.Fatalf("server Read: %v", err)
  113. }
  114. got = string(buf[:n])
  115. if got != magicExt {
  116. t.Fatalf("server: got %q want %q", got, magic)
  117. }
  118. }
  119. func TestMuxChannelOverflow(t *testing.T) {
  120. reader, writer, mux := channelPair(t)
  121. defer reader.Close()
  122. defer writer.Close()
  123. defer mux.Close()
  124. wDone := make(chan int, 1)
  125. go func() {
  126. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  127. t.Errorf("could not fill window: %v", err)
  128. }
  129. writer.Write(make([]byte, 1))
  130. wDone <- 1
  131. }()
  132. writer.remoteWin.waitWriterBlocked()
  133. // Send 1 byte.
  134. packet := make([]byte, 1+4+4+1)
  135. packet[0] = msgChannelData
  136. marshalUint32(packet[1:], writer.remoteId)
  137. marshalUint32(packet[5:], uint32(1))
  138. packet[9] = 42
  139. if err := writer.mux.conn.writePacket(packet); err != nil {
  140. t.Errorf("could not send packet")
  141. }
  142. if _, err := reader.SendRequest("hello", true, nil); err == nil {
  143. t.Errorf("SendRequest succeeded.")
  144. }
  145. <-wDone
  146. }
  147. func TestMuxChannelCloseWriteUnblock(t *testing.T) {
  148. reader, writer, mux := channelPair(t)
  149. defer reader.Close()
  150. defer writer.Close()
  151. defer mux.Close()
  152. wDone := make(chan int, 1)
  153. go func() {
  154. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  155. t.Errorf("could not fill window: %v", err)
  156. }
  157. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  158. t.Errorf("got %v, want EOF for unblock write", err)
  159. }
  160. wDone <- 1
  161. }()
  162. writer.remoteWin.waitWriterBlocked()
  163. reader.Close()
  164. <-wDone
  165. }
  166. func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
  167. reader, writer, mux := channelPair(t)
  168. defer reader.Close()
  169. defer writer.Close()
  170. defer mux.Close()
  171. wDone := make(chan int, 1)
  172. go func() {
  173. if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
  174. t.Errorf("could not fill window: %v", err)
  175. }
  176. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  177. t.Errorf("got %v, want EOF for unblock write", err)
  178. }
  179. wDone <- 1
  180. }()
  181. writer.remoteWin.waitWriterBlocked()
  182. mux.Close()
  183. <-wDone
  184. }
  185. func TestMuxReject(t *testing.T) {
  186. client, server := muxPair()
  187. defer server.Close()
  188. defer client.Close()
  189. go func() {
  190. ch, ok := <-server.incomingChannels
  191. if !ok {
  192. t.Fatalf("Accept")
  193. }
  194. if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
  195. t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
  196. }
  197. ch.Reject(RejectionReason(42), "message")
  198. }()
  199. ch, err := client.openChannel("ch", []byte("extra"))
  200. if ch != nil {
  201. t.Fatal("openChannel not rejected")
  202. }
  203. ocf, ok := err.(*OpenChannelError)
  204. if !ok {
  205. t.Errorf("got %#v want *OpenChannelError", err)
  206. } else if ocf.Reason != 42 || ocf.Message != "message" {
  207. t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
  208. }
  209. want := "ssh: rejected: unknown reason 42 (message)"
  210. if err.Error() != want {
  211. t.Errorf("got %q, want %q", err.Error(), want)
  212. }
  213. }
  214. func TestMuxChannelRequest(t *testing.T) {
  215. client, server, mux := channelPair(t)
  216. defer server.Close()
  217. defer client.Close()
  218. defer mux.Close()
  219. var received int
  220. var wg sync.WaitGroup
  221. wg.Add(1)
  222. go func() {
  223. for r := range server.incomingRequests {
  224. received++
  225. r.Reply(r.Type == "yes", nil)
  226. }
  227. wg.Done()
  228. }()
  229. _, err := client.SendRequest("yes", false, nil)
  230. if err != nil {
  231. t.Fatalf("SendRequest: %v", err)
  232. }
  233. ok, err := client.SendRequest("yes", true, nil)
  234. if err != nil {
  235. t.Fatalf("SendRequest: %v", err)
  236. }
  237. if !ok {
  238. t.Errorf("SendRequest(yes): %v", ok)
  239. }
  240. ok, err = client.SendRequest("no", true, nil)
  241. if err != nil {
  242. t.Fatalf("SendRequest: %v", err)
  243. }
  244. if ok {
  245. t.Errorf("SendRequest(no): %v", ok)
  246. }
  247. client.Close()
  248. wg.Wait()
  249. if received != 3 {
  250. t.Errorf("got %d requests, want %d", received, 3)
  251. }
  252. }
  253. func TestMuxUnknownChannelRequests(t *testing.T) {
  254. clientPipe, serverPipe := memPipe()
  255. client := newMux(clientPipe)
  256. defer serverPipe.Close()
  257. defer client.Close()
  258. kDone := make(chan struct{})
  259. go func() {
  260. // Ignore unknown channel messages that don't want a reply.
  261. err := serverPipe.writePacket(Marshal(channelRequestMsg{
  262. PeersID: 1,
  263. Request: "keepalive@openssh.com",
  264. WantReply: false,
  265. RequestSpecificData: []byte{},
  266. }))
  267. if err != nil {
  268. t.Fatalf("send: %v", err)
  269. }
  270. // Send a keepalive, which should get a channel failure message
  271. // in response.
  272. err = serverPipe.writePacket(Marshal(channelRequestMsg{
  273. PeersID: 2,
  274. Request: "keepalive@openssh.com",
  275. WantReply: true,
  276. RequestSpecificData: []byte{},
  277. }))
  278. if err != nil {
  279. t.Fatalf("send: %v", err)
  280. }
  281. packet, err := serverPipe.readPacket()
  282. if err != nil {
  283. t.Fatalf("read packet: %v", err)
  284. }
  285. decoded, err := decode(packet)
  286. if err != nil {
  287. t.Fatalf("decode failed: %v", err)
  288. }
  289. switch msg := decoded.(type) {
  290. case *channelRequestFailureMsg:
  291. if msg.PeersID != 2 {
  292. t.Fatalf("received response to wrong message: %v", msg)
  293. }
  294. default:
  295. t.Fatalf("unexpected channel message: %v", msg)
  296. }
  297. kDone <- struct{}{}
  298. // Receive and respond to the keepalive to confirm the mux is
  299. // still processing requests.
  300. packet, err = serverPipe.readPacket()
  301. if err != nil {
  302. t.Fatalf("read packet: %v", err)
  303. }
  304. if packet[0] != msgGlobalRequest {
  305. t.Fatalf("expected global request")
  306. }
  307. err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
  308. Data: []byte{},
  309. }))
  310. if err != nil {
  311. t.Fatalf("failed to send failure msg: %v", err)
  312. }
  313. close(kDone)
  314. }()
  315. // Wait for the server to send the keepalive message and receive back a
  316. // response.
  317. select {
  318. case <-kDone:
  319. case <-time.After(10 * time.Second):
  320. t.Fatalf("server never received ack")
  321. }
  322. // Confirm client hasn't closed.
  323. if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
  324. t.Fatalf("failed to send keepalive: %v", err)
  325. }
  326. select {
  327. case <-kDone:
  328. case <-time.After(10 * time.Second):
  329. t.Fatalf("server never shut down")
  330. }
  331. }
  332. func TestMuxClosedChannel(t *testing.T) {
  333. clientPipe, serverPipe := memPipe()
  334. client := newMux(clientPipe)
  335. defer serverPipe.Close()
  336. defer client.Close()
  337. kDone := make(chan struct{})
  338. go func() {
  339. // Open the channel.
  340. packet, err := serverPipe.readPacket()
  341. if err != nil {
  342. t.Fatalf("read packet: %v", err)
  343. }
  344. if packet[0] != msgChannelOpen {
  345. t.Fatalf("expected chan open")
  346. }
  347. var openMsg channelOpenMsg
  348. if err := Unmarshal(packet, &openMsg); err != nil {
  349. t.Fatalf("unmarshal: %v", err)
  350. }
  351. // Send back the opened channel confirmation.
  352. err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{
  353. PeersID: openMsg.PeersID,
  354. MyID: 0,
  355. MyWindow: 0,
  356. MaxPacketSize: channelMaxPacket,
  357. }))
  358. if err != nil {
  359. t.Fatalf("send: %v", err)
  360. }
  361. // Close the channel.
  362. err = serverPipe.writePacket(Marshal(channelCloseMsg{
  363. PeersID: openMsg.PeersID,
  364. }))
  365. if err != nil {
  366. t.Fatalf("send: %v", err)
  367. }
  368. // Send a keepalive message on the channel we just closed.
  369. err = serverPipe.writePacket(Marshal(channelRequestMsg{
  370. PeersID: openMsg.PeersID,
  371. Request: "keepalive@openssh.com",
  372. WantReply: true,
  373. RequestSpecificData: []byte{},
  374. }))
  375. if err != nil {
  376. t.Fatalf("send: %v", err)
  377. }
  378. // Receive the channel closed response.
  379. packet, err = serverPipe.readPacket()
  380. if err != nil {
  381. t.Fatalf("read packet: %v", err)
  382. }
  383. if packet[0] != msgChannelClose {
  384. t.Fatalf("expected channel close")
  385. }
  386. // Receive the keepalive response failure.
  387. packet, err = serverPipe.readPacket()
  388. if err != nil {
  389. t.Fatalf("read packet: %v", err)
  390. }
  391. if packet[0] != msgChannelFailure {
  392. t.Fatalf("expected channel close")
  393. }
  394. kDone <- struct{}{}
  395. // Receive and respond to the keepalive to confirm the mux is
  396. // still processing requests.
  397. packet, err = serverPipe.readPacket()
  398. if err != nil {
  399. t.Fatalf("read packet: %v", err)
  400. }
  401. if packet[0] != msgGlobalRequest {
  402. t.Fatalf("expected global request")
  403. }
  404. err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
  405. Data: []byte{},
  406. }))
  407. if err != nil {
  408. t.Fatalf("failed to send failure msg: %v", err)
  409. }
  410. close(kDone)
  411. }()
  412. // Open a channel.
  413. ch, err := client.openChannel("chan", nil)
  414. if err != nil {
  415. t.Fatalf("OpenChannel: %v", err)
  416. }
  417. defer ch.Close()
  418. // Wait for the server to close the channel and send the keepalive.
  419. select {
  420. case <-kDone:
  421. case <-time.After(10 * time.Second):
  422. t.Fatalf("server never received ack")
  423. }
  424. // Make sure the channel closed.
  425. if _, ok := <-ch.incomingRequests; ok {
  426. t.Fatalf("channel not closed")
  427. }
  428. // Confirm client hasn't closed
  429. if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
  430. t.Fatalf("failed to send keepalive: %v", err)
  431. }
  432. select {
  433. case <-kDone:
  434. case <-time.After(10 * time.Second):
  435. t.Fatalf("server never shut down")
  436. }
  437. }
  438. func TestMuxGlobalRequest(t *testing.T) {
  439. clientMux, serverMux := muxPair()
  440. defer serverMux.Close()
  441. defer clientMux.Close()
  442. var seen bool
  443. go func() {
  444. for r := range serverMux.incomingRequests {
  445. seen = seen || r.Type == "peek"
  446. if r.WantReply {
  447. err := r.Reply(r.Type == "yes",
  448. append([]byte(r.Type), r.Payload...))
  449. if err != nil {
  450. t.Errorf("AckRequest: %v", err)
  451. }
  452. }
  453. }
  454. }()
  455. _, _, err := clientMux.SendRequest("peek", false, nil)
  456. if err != nil {
  457. t.Errorf("SendRequest: %v", err)
  458. }
  459. ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
  460. if !ok || string(data) != "yesa" || err != nil {
  461. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  462. ok, data, err)
  463. }
  464. if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
  465. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  466. ok, data, err)
  467. }
  468. if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
  469. t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
  470. ok, data, err)
  471. }
  472. if !seen {
  473. t.Errorf("never saw 'peek' request")
  474. }
  475. }
  476. func TestMuxGlobalRequestUnblock(t *testing.T) {
  477. clientMux, serverMux := muxPair()
  478. defer serverMux.Close()
  479. defer clientMux.Close()
  480. result := make(chan error, 1)
  481. go func() {
  482. _, _, err := clientMux.SendRequest("hello", true, nil)
  483. result <- err
  484. }()
  485. <-serverMux.incomingRequests
  486. serverMux.conn.Close()
  487. err := <-result
  488. if err != io.EOF {
  489. t.Errorf("want EOF, got %v", io.EOF)
  490. }
  491. }
  492. func TestMuxChannelRequestUnblock(t *testing.T) {
  493. a, b, connB := channelPair(t)
  494. defer a.Close()
  495. defer b.Close()
  496. defer connB.Close()
  497. result := make(chan error, 1)
  498. go func() {
  499. _, err := a.SendRequest("hello", true, nil)
  500. result <- err
  501. }()
  502. <-b.incomingRequests
  503. connB.conn.Close()
  504. err := <-result
  505. if err != io.EOF {
  506. t.Errorf("want EOF, got %v", err)
  507. }
  508. }
  509. func TestMuxCloseChannel(t *testing.T) {
  510. r, w, mux := channelPair(t)
  511. defer mux.Close()
  512. defer r.Close()
  513. defer w.Close()
  514. result := make(chan error, 1)
  515. go func() {
  516. var b [1024]byte
  517. _, err := r.Read(b[:])
  518. result <- err
  519. }()
  520. if err := w.Close(); err != nil {
  521. t.Errorf("w.Close: %v", err)
  522. }
  523. if _, err := w.Write([]byte("hello")); err != io.EOF {
  524. t.Errorf("got err %v, want io.EOF after Close", err)
  525. }
  526. if err := <-result; err != io.EOF {
  527. t.Errorf("got %v (%T), want io.EOF", err, err)
  528. }
  529. }
  530. func TestMuxCloseWriteChannel(t *testing.T) {
  531. r, w, mux := channelPair(t)
  532. defer mux.Close()
  533. result := make(chan error, 1)
  534. go func() {
  535. var b [1024]byte
  536. _, err := r.Read(b[:])
  537. result <- err
  538. }()
  539. if err := w.CloseWrite(); err != nil {
  540. t.Errorf("w.CloseWrite: %v", err)
  541. }
  542. if _, err := w.Write([]byte("hello")); err != io.EOF {
  543. t.Errorf("got err %v, want io.EOF after CloseWrite", err)
  544. }
  545. if err := <-result; err != io.EOF {
  546. t.Errorf("got %v (%T), want io.EOF", err, err)
  547. }
  548. }
  549. func TestMuxInvalidRecord(t *testing.T) {
  550. a, b := muxPair()
  551. defer a.Close()
  552. defer b.Close()
  553. packet := make([]byte, 1+4+4+1)
  554. packet[0] = msgChannelData
  555. marshalUint32(packet[1:], 29348723 /* invalid channel id */)
  556. marshalUint32(packet[5:], 1)
  557. packet[9] = 42
  558. a.conn.writePacket(packet)
  559. go a.SendRequest("hello", false, nil)
  560. // 'a' wrote an invalid packet, so 'b' has exited.
  561. req, ok := <-b.incomingRequests
  562. if ok {
  563. t.Errorf("got request %#v after receiving invalid packet", req)
  564. }
  565. }
  566. func TestZeroWindowAdjust(t *testing.T) {
  567. a, b, mux := channelPair(t)
  568. defer a.Close()
  569. defer b.Close()
  570. defer mux.Close()
  571. go func() {
  572. io.WriteString(a, "hello")
  573. // bogus adjust.
  574. a.sendMessage(windowAdjustMsg{})
  575. io.WriteString(a, "world")
  576. a.Close()
  577. }()
  578. want := "helloworld"
  579. c, _ := ioutil.ReadAll(b)
  580. if string(c) != want {
  581. t.Errorf("got %q want %q", c, want)
  582. }
  583. }
  584. func TestMuxMaxPacketSize(t *testing.T) {
  585. a, b, mux := channelPair(t)
  586. defer a.Close()
  587. defer b.Close()
  588. defer mux.Close()
  589. large := make([]byte, a.maxRemotePayload+1)
  590. packet := make([]byte, 1+4+4+1+len(large))
  591. packet[0] = msgChannelData
  592. marshalUint32(packet[1:], a.remoteId)
  593. marshalUint32(packet[5:], uint32(len(large)))
  594. packet[9] = 42
  595. if err := a.mux.conn.writePacket(packet); err != nil {
  596. t.Errorf("could not send packet")
  597. }
  598. go a.SendRequest("hello", false, nil)
  599. _, ok := <-b.incomingRequests
  600. if ok {
  601. t.Errorf("connection still alive after receiving large packet.")
  602. }
  603. }
  604. // Don't ship code with debug=true.
  605. func TestDebug(t *testing.T) {
  606. if debugMux {
  607. t.Error("mux debug switched on")
  608. }
  609. if debugHandshake {
  610. t.Error("handshake debug switched on")
  611. }
  612. if debugTransport {
  613. t.Error("transport debug switched on")
  614. }
  615. }