server_test.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package agent
  5. import (
  6. "crypto"
  7. "crypto/rand"
  8. "fmt"
  9. pseudorand "math/rand"
  10. "reflect"
  11. "strings"
  12. "testing"
  13. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
  14. )
  15. func TestServer(t *testing.T) {
  16. c1, c2, err := netPipe()
  17. if err != nil {
  18. t.Fatalf("netPipe: %v", err)
  19. }
  20. defer c1.Close()
  21. defer c2.Close()
  22. client := NewClient(c1)
  23. go ServeAgent(NewKeyring(), c2)
  24. testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
  25. }
  26. func TestLockServer(t *testing.T) {
  27. testLockAgent(NewKeyring(), t)
  28. }
  29. func TestSetupForwardAgent(t *testing.T) {
  30. a, b, err := netPipe()
  31. if err != nil {
  32. t.Fatalf("netPipe: %v", err)
  33. }
  34. defer a.Close()
  35. defer b.Close()
  36. _, socket, cleanup := startOpenSSHAgent(t)
  37. defer cleanup()
  38. serverConf := ssh.ServerConfig{
  39. NoClientAuth: true,
  40. }
  41. serverConf.AddHostKey(testSigners["rsa"])
  42. incoming := make(chan *ssh.ServerConn, 1)
  43. go func() {
  44. conn, _, _, err := ssh.NewServerConn(a, &serverConf)
  45. incoming <- conn
  46. if err != nil {
  47. t.Errorf("NewServerConn error: %v", err)
  48. return
  49. }
  50. }()
  51. conf := ssh.ClientConfig{
  52. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  53. }
  54. conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
  55. if err != nil {
  56. t.Fatalf("NewClientConn: %v", err)
  57. }
  58. client := ssh.NewClient(conn, chans, reqs)
  59. if err := ForwardToRemote(client, socket); err != nil {
  60. t.Fatalf("SetupForwardAgent: %v", err)
  61. }
  62. server := <-incoming
  63. if server == nil {
  64. t.Fatal("Unable to get server")
  65. }
  66. ch, reqs, err := server.OpenChannel(channelType, nil)
  67. if err != nil {
  68. t.Fatalf("OpenChannel(%q): %v", channelType, err)
  69. }
  70. go ssh.DiscardRequests(reqs)
  71. agentClient := NewClient(ch)
  72. testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
  73. conn.Close()
  74. }
  75. func TestV1ProtocolMessages(t *testing.T) {
  76. c1, c2, err := netPipe()
  77. if err != nil {
  78. t.Fatalf("netPipe: %v", err)
  79. }
  80. defer c1.Close()
  81. defer c2.Close()
  82. c := NewClient(c1)
  83. go ServeAgent(NewKeyring(), c2)
  84. testV1ProtocolMessages(t, c.(*client))
  85. }
  86. func testV1ProtocolMessages(t *testing.T, c *client) {
  87. reply, err := c.call([]byte{agentRequestV1Identities})
  88. if err != nil {
  89. t.Fatalf("v1 request all failed: %v", err)
  90. }
  91. if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
  92. t.Fatalf("invalid request all response: %#v", reply)
  93. }
  94. reply, err = c.call([]byte{agentRemoveAllV1Identities})
  95. if err != nil {
  96. t.Fatalf("v1 remove all failed: %v", err)
  97. }
  98. if _, ok := reply.(*successAgentMsg); !ok {
  99. t.Fatalf("invalid remove all response: %#v", reply)
  100. }
  101. }
  102. func verifyKey(sshAgent Agent) error {
  103. keys, err := sshAgent.List()
  104. if err != nil {
  105. return fmt.Errorf("listing keys: %v", err)
  106. }
  107. if len(keys) != 1 {
  108. return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
  109. }
  110. buf := make([]byte, 128)
  111. if _, err := rand.Read(buf); err != nil {
  112. return fmt.Errorf("rand: %v", err)
  113. }
  114. sig, err := sshAgent.Sign(keys[0], buf)
  115. if err != nil {
  116. return fmt.Errorf("sign: %v", err)
  117. }
  118. if err := keys[0].Verify(buf, sig); err != nil {
  119. return fmt.Errorf("verify: %v", err)
  120. }
  121. return nil
  122. }
  123. func addKeyToAgent(key crypto.PrivateKey) error {
  124. sshAgent := NewKeyring()
  125. if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
  126. return fmt.Errorf("add: %v", err)
  127. }
  128. return verifyKey(sshAgent)
  129. }
  130. func TestKeyTypes(t *testing.T) {
  131. for k, v := range testPrivateKeys {
  132. if err := addKeyToAgent(v); err != nil {
  133. t.Errorf("error adding key type %s, %v", k, err)
  134. }
  135. if err := addCertToAgentSock(v, nil); err != nil {
  136. t.Errorf("error adding key type %s, %v", k, err)
  137. }
  138. }
  139. }
  140. func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
  141. a, b, err := netPipe()
  142. if err != nil {
  143. return err
  144. }
  145. agentServer := NewKeyring()
  146. go ServeAgent(agentServer, a)
  147. agentClient := NewClient(b)
  148. if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
  149. return fmt.Errorf("add: %v", err)
  150. }
  151. return verifyKey(agentClient)
  152. }
  153. func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
  154. sshAgent := NewKeyring()
  155. if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
  156. return fmt.Errorf("add: %v", err)
  157. }
  158. return verifyKey(sshAgent)
  159. }
  160. func TestCertTypes(t *testing.T) {
  161. for keyType, key := range testPublicKeys {
  162. cert := &ssh.Certificate{
  163. ValidPrincipals: []string{"gopher1"},
  164. ValidAfter: 0,
  165. ValidBefore: ssh.CertTimeInfinity,
  166. Key: key,
  167. Serial: 1,
  168. CertType: ssh.UserCert,
  169. SignatureKey: testPublicKeys["rsa"],
  170. Permissions: ssh.Permissions{
  171. CriticalOptions: map[string]string{},
  172. Extensions: map[string]string{},
  173. },
  174. }
  175. if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
  176. t.Fatalf("signcert: %v", err)
  177. }
  178. if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
  179. t.Fatalf("%v", err)
  180. }
  181. if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
  182. t.Fatalf("%v", err)
  183. }
  184. }
  185. }
  186. func TestParseConstraints(t *testing.T) {
  187. // Test LifetimeSecs
  188. var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
  189. lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
  190. if err != nil {
  191. t.Fatalf("parseConstraints: %v", err)
  192. }
  193. if lifetimeSecs != msg.LifetimeSecs {
  194. t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
  195. }
  196. // Test ConfirmBeforeUse
  197. _, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
  198. if err != nil {
  199. t.Fatalf("%v", err)
  200. }
  201. if !confirmBeforeUse {
  202. t.Error("got comfirmBeforeUse == false")
  203. }
  204. // Test ConstraintExtensions
  205. var data []byte
  206. var expect []ConstraintExtension
  207. for i := 0; i < 10; i++ {
  208. var ext = ConstraintExtension{
  209. ExtensionName: fmt.Sprintf("name%d", i),
  210. ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
  211. }
  212. expect = append(expect, ext)
  213. if i%2 == 0 {
  214. data = append(data, agentConstrainExtension)
  215. } else {
  216. data = append(data, agentConstrainExtensionV00)
  217. }
  218. data = append(data, ssh.Marshal(ext)...)
  219. }
  220. _, _, extensions, err := parseConstraints(data)
  221. if err != nil {
  222. t.Fatalf("%v", err)
  223. }
  224. if !reflect.DeepEqual(expect, extensions) {
  225. t.Errorf("got extension %v, want %v", extensions, expect)
  226. }
  227. // Test Unknown Constraint
  228. _, _, _, err = parseConstraints([]byte{128})
  229. if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
  230. t.Errorf("unexpected error: %v", err)
  231. }
  232. }