server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. // Copyright 2023 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net"
  11. "reflect"
  12. "strings"
  13. "sync/atomic"
  14. "testing"
  15. "time"
  16. )
  // TestClientAuthRestrictedPublicKeyAlgos verifies that a server whose
// PublicKeyAuthAlgorithms lists only KeyAlgoRSASHA256 and KeyAlgoRSASHA512
// accepts an RSA client key and rejects DSA and Ed25519 client keys.
//
// Each case runs as its own subtest so that its connection pair is closed as
// soon as the case finishes; a defer in a plain loop body would only run when
// the whole test function returns.
func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
	for _, tt := range []struct {
		name      string
		key       Signer
		wantError bool
	}{
		{"rsa", testSigners["rsa"], false},
		{"dsa", testSigners["dsa"], true},
		{"ed25519", testSigners["ed25519"], true},
	} {
		t.Run(tt.name, func(t *testing.T) {
			c1, c2, err := netPipe()
			if err != nil {
				t.Fatalf("netPipe: %v", err)
			}
			defer c1.Close()
			defer c2.Close()

			serverConf := &ServerConfig{
				PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
				PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
					return nil, nil
				},
			}
			serverConf.AddHostKey(testSigners["ecdsap256"])

			// done is closed when the server side returns, so the subtest
			// never finishes while the server goroutine is still running.
			done := make(chan struct{})
			go func() {
				defer close(done)
				NewServerConn(c1, serverConf)
			}()

			clientConf := ClientConfig{
				User: "user",
				Auth: []AuthMethod{
					PublicKeys(tt.key),
				},
				HostKeyCallback: InsecureIgnoreHostKey(),
			}
			_, _, _, err = NewClientConn(c2, "", &clientConf)
			switch {
			case err != nil && !tt.wantError:
				t.Errorf("%s: got unexpected error %q", tt.name, err.Error())
			case err == nil && tt.wantError:
				t.Errorf("%s: succeeded, but want error", tt.name)
			}
			<-done
		})
	}
}
  // TestMaxAuthTriesNoneMethod verifies that repeated "none" authentication
// requests are limited by ServerConfig.MaxAuthTries.
//
// The regular client sends "none" only once, so this test drives the
// connection by hand:
//   - exchange versions and wait for key exchange;
//   - request the ssh-userauth service;
//   - issue MaxAuthTries+1 "none" requests.
//
// The first MaxAuthTries requests must complete without a transport error.
// The last one must fail with "too many authentication failures". Every
// attempt must be reported to AuthLogCallback with an error matching
// ErrNoAuth.
func TestMaxAuthTriesNoneMethod(t *testing.T) {
	username := "testuser"
	serverConfig := &ServerConfig{
		MaxAuthTries: 2,
		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
			if conn.User() == username && string(password) == clientPassword {
				return nil, nil
			}
			return nil, errors.New("invalid credentials")
		},
	}
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()

	// NOTE(review): appended from the server goroutine and read here after
	// the auth loop. The ordering relies on the server logging an attempt
	// before it replies to the client — confirm if this test becomes flaky.
	var serverAuthErrors []error
	serverConfig.AddHostKey(testSigners["rsa"])
	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
		serverAuthErrors = append(serverAuthErrors, err)
	}
	go newServer(c1, serverConfig)

	clientConfig := ClientConfig{
		User:            username,
		HostKeyCallback: InsecureIgnoreHostKey(),
	}
	clientConfig.SetDefaults()
	// Our client will send 'none' auth only once, so we need to send the
	// requests manually.
	c := &connection{
		sshConn: sshConn{
			conn:          c2,
			user:          username,
			clientVersion: []byte(packageVersion),
		},
	}
	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
	if err != nil {
		t.Fatalf("unable to exchange version: %v", err)
	}
	c.transport = newClientTransport(
		newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */),
		c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr())
	if err := c.transport.waitSession(); err != nil {
		t.Fatalf("unable to wait session: %v", err)
	}
	c.sessionID = c.transport.getSessionID()
	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
		t.Fatalf("unable to send ssh-userauth message: %v", err)
	}
	packet, err := c.transport.readPacket()
	if err != nil {
		t.Fatal(err)
	}
	// An SSH_MSG_EXT_INFO packet may precede the service accept; skip it.
	if len(packet) > 0 && packet[0] == msgExtInfo {
		packet, err = c.transport.readPacket()
		if err != nil {
			t.Fatal(err)
		}
	}
	var serviceAccept serviceAcceptMsg
	if err := Unmarshal(packet, &serviceAccept); err != nil {
		t.Fatal(err)
	}

	// One more attempt than MaxAuthTries: the extra one must be refused.
	attempts := serverConfig.MaxAuthTries + 1
	for i := 0; i < attempts; i++ {
		auth := new(noneAuth)
		_, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil)
		if i < serverConfig.MaxAuthTries {
			if err != nil {
				t.Fatal(err)
			}
			continue
		}
		switch {
		case err == nil:
			t.Fatal("client: got no error")
		case !strings.Contains(err.Error(), "too many authentication failures"):
			t.Fatalf("client: got unexpected error: %v", err)
		}
	}

	// The server is expected to log every attempt, including the refused one.
	if len(serverAuthErrors) != attempts {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	for _, err := range serverAuthErrors {
		if !errors.Is(err, ErrNoAuth) {
			t.Errorf("got error: %v; want: %v", err, ErrNoAuth)
		}
	}
}
  // TestMaxAuthTriesFirstNoneAuthErrorIgnored verifies that, with MaxAuthTries
// set to 1, the client's password login still succeeds after a failed first
// attempt.
//
// The server must log exactly two attempts:
//   - a first attempt rejected with an error matching ErrNoAuth (the initial
//     "none" request, per the test name);
//   - a second, successful attempt logged with a nil error.
func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
	username := "testuser"
	serverConfig := &ServerConfig{
		MaxAuthTries: 1,
		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
			if conn.User() == username && string(password) == clientPassword {
				return nil, nil
			}
			return nil, errors.New("invalid credentials")
		},
	}
	clientConfig := &ClientConfig{
		User: username,
		Auth: []AuthMethod{
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}
	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
	if err != nil {
		t.Fatalf("client login error: %v", err)
	}
	if len(serverAuthErrors) != 2 {
		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
	}
	if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
		t.Errorf("got error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
	}
	if serverAuthErrors[1] != nil {
		t.Errorf("unexpected error: %v", serverAuthErrors[1])
	}
}
  // TestNewServerConnValidationErrors checks that NewServerConn rejects an
// invalid ServerConfig before touching the connection. For each bad config it
// must:
//   - return an error;
//   - close the supplied connection;
//   - never read from or write to it.
func TestNewServerConnValidationErrors(t *testing.T) {
	cases := []struct {
		desc string
		conf *ServerConfig
	}{
		{
			desc: "invalid public key auth algorithms",
			conf: &ServerConfig{
				PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
			},
		},
		{
			desc: "unsupported key exchange",
			conf: &ServerConfig{
				Config: Config{
					KeyExchanges: []string{kexAlgoDHGEXSHA256},
				},
			},
		},
	}
	for _, tc := range cases {
		conn := new(markerConn)
		if _, _, _, err := NewServerConn(conn, tc.conf); err == nil {
			t.Fatalf("NewServerConn with %s succeeded", tc.desc)
		}
		switch {
		case !conn.isClosed():
			t.Fatalf("NewServerConn with %s left connection open", tc.desc)
		case conn.isUsed():
			t.Fatalf("NewServerConn with %s used connection", tc.desc)
		}
	}
}
  // TestBannerError verifies that when a server auth callback returns a
// *BannerError, its Message is delivered to the client as an authentication
// banner. This must hold when the error is returned directly, when it is
// wrapped with %w, and when its inner Err is nil.
//
// The client offers public key, keyboard-interactive and password auth. The
// server answers as follows:
//   - the public key callback rejects with a banner error;
//   - the keyboard-interactive callback rejects with a banner error;
//   - the password callback accepts, so the connection must be established.
//
// The client must record the BannerCallback, NoClientAuthCallback,
// PublicKeyCallback and KeyboardInteractiveCallback banners, in that order.
func TestBannerError(t *testing.T) {
	// bannerErr builds the *BannerError rejections used by the callbacks below.
	bannerErr := func(inner error, message string) *BannerError {
		return &BannerError{Err: inner, Message: message}
	}

	serverConfig := &ServerConfig{
		BannerCallback: func(ConnMetadata) string {
			return "banner from BannerCallback"
		},
		NoClientAuth: true,
		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
			// Wrapped, so the banner has to be found through %w.
			inner := bannerErr(errors.New("error from NoClientAuthCallback"), "banner from NoClientAuthCallback")
			return nil, fmt.Errorf("wrapped: %w", inner)
		},
		PasswordCallback: func(ConnMetadata, []byte) (*Permissions, error) {
			return &Permissions{}, nil
		},
		PublicKeyCallback: func(ConnMetadata, PublicKey) (*Permissions, error) {
			return nil, bannerErr(errors.New("error from PublicKeyCallback"), "banner from PublicKeyCallback")
		},
		KeyboardInteractiveCallback: func(ConnMetadata, KeyboardInteractiveChallenge) (*Permissions, error) {
			// A nil inner error must be accepted.
			return nil, bannerErr(nil, "banner from KeyboardInteractiveCallback")
		},
	}
	serverConfig.AddHostKey(testSigners["rsa"])

	var got []string
	clientConfig := &ClientConfig{
		User: "test",
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"]),
			KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
				return []string{"letmein"}, nil
			}),
			Password(clientPassword),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
		BannerCallback: func(msg string) error {
			got = append(got, msg)
			return nil
		},
	}

	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()

	go newServer(c1, serverConfig)

	conn, _, _, err := NewClientConn(c2, "", clientConfig)
	if err != nil {
		t.Fatalf("client connection failed: %v", err)
	}
	defer conn.Close()

	want := []string{
		"banner from BannerCallback",
		"banner from NoClientAuthCallback",
		"banner from PublicKeyCallback",
		"banner from KeyboardInteractiveCallback",
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got banners:\n%q\nwant banners:\n%q", got, want)
	}
}
  // TestPublicKeyCallbackLastSeen verifies the last key handed to
// PublicKeyCallback. The client offers RSA, DSA and Ed25519 keys, and the
// callback accepts only the DSA key. The last key the callback saw must
// therefore be the DSA key.
func TestPublicKeyCallbackLastSeen(t *testing.T) {
	// Written by the server goroutine; read only after done is closed.
	var lastSeenKey PublicKey
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()
	serverConf := &ServerConfig{
		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
			lastSeenKey = key
			// Log via t rather than stdout so output only shows on failure or -v.
			t.Logf("seen %#v", key)
			if _, ok := key.(*dsaPublicKey); !ok {
				return nil, errors.New("nope")
			}
			return nil, nil
		},
	}
	serverConf.AddHostKey(testSigners["ecdsap256"])
	done := make(chan struct{})
	go func() {
		defer close(done)
		NewServerConn(c1, serverConf)
	}()
	clientConf := ClientConfig{
		User: "user",
		Auth: []AuthMethod{
			PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]),
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}
	_, _, _, err = NewClientConn(c2, "", &clientConf)
	if err != nil {
		// Close our end so the server handshake stops. Then wait for the
		// server goroutine before failing, so its callback (which logs via t)
		// cannot run after the test has returned.
		c2.Close()
		<-done
		t.Fatal(err)
	}
	<-done

	if lastSeenKey == nil {
		t.Fatal("PublicKeyCallback was never called")
	}
	want := testSigners["dsa"].PublicKey()
	if !bytes.Equal(lastSeenKey.Marshal(), want.Marshal()) {
		t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, want)
	}
}
  // TestPreAuthConnAndBanners exercises the ServerPreAuthConn handed to
// PreAuthConnCallback. It checks that:
//   - banners sent through it during authentication reach the client in
//     order;
//   - SendAuthBanner can be called concurrently from another goroutine (run
//     with -race to get value from this);
//   - once authentication has completed, SendAuthBanner fails with
//     errSendBannerPhase.
func TestPreAuthConnAndBanners(t *testing.T) {
	// testDone stops the banner-spamming goroutine started in the callback.
	// It is deferred first, so it runs last, after the connections deferred
	// below have been closed.
	testDone := make(chan struct{})
	defer close(testDone)

	// Buffered so the callback's single send never blocks the handshake.
	authConnc := make(chan ServerPreAuthConn, 1)
	serverConfig := &ServerConfig{
		PreAuthConnCallback: func(c ServerPreAuthConn) {
			t.Logf("got ServerPreAuthConn: %v", c)
			authConnc <- c // for use later in the test
			for _, s := range []string{"hello1", "hello2"} {
				if err := c.SendAuthBanner(s); err != nil {
					t.Errorf("failed to send banner %q: %v", s, err)
				}
			}
			// Now start a goroutine to spam SendAuthBanner in hopes
			// of hitting a race. Calls made after authentication are
			// expected to fail with errSendBannerPhase, which is tolerated.
			go func() {
				for {
					select {
					case <-testDone:
						return
					default:
						if err := c.SendAuthBanner("attempted-race"); err != nil && !errors.Is(err, errSendBannerPhase) {
							t.Errorf("unexpected error from SendAuthBanner: %v", err)
						}
						time.Sleep(5 * time.Millisecond)
					}
				}
			}()
		},
		NoClientAuth: true,
		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
			t.Logf("got NoClientAuthCallback")
			return &Permissions{}, nil
		},
	}
	serverConfig.AddHostKey(testSigners["rsa"])
	var banners []string
	clientConfig := &ClientConfig{
		User:            "test",
		HostKeyCallback: InsecureIgnoreHostKey(),
		BannerCallback: func(msg string) error {
			// The spammer's banners may arrive at any point; ignore them.
			if msg != "attempted-race" {
				banners = append(banners, msg)
			}
			return nil
		},
	}
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()
	go newServer(c1, serverConfig)
	c, _, _, err := NewClientConn(c2, "", clientConfig)
	if err != nil {
		t.Fatalf("client connection failed: %v", err)
	}
	defer c.Close()
	wantBanners := []string{
		"hello1",
		"hello2",
	}
	if !reflect.DeepEqual(banners, wantBanners) {
		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
	}
	// Now that we're authenticated, verify that use of SendBanner
	// is an error. The callback sent on authConnc before it sent "hello1",
	// so a non-blocking receive is enough here.
	var bc ServerPreAuthConn
	select {
	case bc = <-authConnc:
	default:
		t.Fatal("expected ServerPreAuthConn")
	}
	if err := bc.SendAuthBanner("wrong-phase"); err == nil {
		t.Error("unexpected success of SendAuthBanner after authentication")
	} else if !errors.Is(err, errSendBannerPhase) {
		t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
	}
}
  406. type markerConn struct {
  407. closed uint32
  408. used uint32
  409. }
  410. func (c *markerConn) isClosed() bool {
  411. return atomic.LoadUint32(&c.closed) != 0
  412. }
  413. func (c *markerConn) isUsed() bool {
  414. return atomic.LoadUint32(&c.used) != 0
  415. }
  416. func (c *markerConn) Close() error {
  417. atomic.StoreUint32(&c.closed, 1)
  418. return nil
  419. }
  420. func (c *markerConn) Read(b []byte) (n int, err error) {
  421. atomic.StoreUint32(&c.used, 1)
  422. if atomic.LoadUint32(&c.closed) != 0 {
  423. return 0, net.ErrClosed
  424. } else {
  425. return 0, io.EOF
  426. }
  427. }
  428. func (c *markerConn) Write(b []byte) (n int, err error) {
  429. atomic.StoreUint32(&c.used, 1)
  430. if atomic.LoadUint32(&c.closed) != 0 {
  431. return 0, net.ErrClosed
  432. } else {
  433. return 0, io.ErrClosedPipe
  434. }
  435. }
  436. func (*markerConn) LocalAddr() net.Addr { return nil }
  437. func (*markerConn) RemoteAddr() net.Addr { return nil }
  438. func (*markerConn) SetDeadline(t time.Time) error { return nil }
  439. func (*markerConn) SetReadDeadline(t time.Time) error { return nil }
  440. func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }