extensions.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. package mint
  2. import (
  3. "bytes"
  4. "fmt"
  5. "github.com/bifurcation/mint/syntax"
  6. )
// ExtensionBody is the interface implemented by the typed extension bodies in
// this file. ExtensionList.Add, Parse and Find use it to handle extensions
// generically.
type ExtensionBody interface {
	// Type reports the ExtensionType code that identifies this extension.
	Type() ExtensionType
	// Marshal encodes the body into the bytes stored as extension data.
	Marshal() ([]byte, error)
	// Unmarshal decodes the body from data. The implementations in this file
	// return the number of bytes read together with any decoding error.
	Unmarshal(data []byte) (int, error)
}
// Extension is a single extension in wire form: a type code plus the
// still-encoded body.
//
// struct {
//     ExtensionType extension_type;
//     opaque extension_data<0..2^16-1>;
// } Extension;
type Extension struct {
	ExtensionType ExtensionType
	// The tag presumably requests a 2-byte length header, matching the
	// <0..2^16-1> bound above (confirm against the syntax package).
	ExtensionData []byte `tls:"head=2"`
}

// Marshal encodes ext by handing it to syntax.Marshal.
func (ext Extension) Marshal() ([]byte, error) {
	return syntax.Marshal(ext)
}

// Unmarshal decodes data into ext via syntax.Unmarshal and returns that
// call's results unchanged.
func (ext *Extension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, ext)
}
// ExtensionList is an ordered collection of extensions.
type ExtensionList []Extension

// extensionListInner wraps the list in a struct field so that a `tls` tag can
// be attached to it for encoding and decoding.
type extensionListInner struct {
	List []Extension `tls:"head=2"`
}

// Marshal encodes the list by wrapping it in extensionListInner and passing
// the wrapper to syntax.Marshal.
func (el ExtensionList) Marshal() ([]byte, error) {
	return syntax.Marshal(extensionListInner{el})
}
  33. func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
  34. var list extensionListInner
  35. read, err := syntax.Unmarshal(data, &list)
  36. if err != nil {
  37. return 0, err
  38. }
  39. *el = list.List
  40. return read, nil
  41. }
  42. func (el *ExtensionList) Add(src ExtensionBody) error {
  43. data, err := src.Marshal()
  44. if err != nil {
  45. return err
  46. }
  47. if el == nil {
  48. el = new(ExtensionList)
  49. }
  50. // If one already exists with this type, replace it
  51. for i := range *el {
  52. if (*el)[i].ExtensionType == src.Type() {
  53. (*el)[i].ExtensionData = data
  54. return nil
  55. }
  56. }
  57. // Otherwise append
  58. *el = append(*el, Extension{
  59. ExtensionType: src.Type(),
  60. ExtensionData: data,
  61. })
  62. return nil
  63. }
  64. func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
  65. found := make(map[ExtensionType]bool)
  66. for _, dst := range dsts {
  67. for _, ext := range el {
  68. if ext.ExtensionType == dst.Type() {
  69. if found[dst.Type()] {
  70. return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
  71. }
  72. err := safeUnmarshal(dst, ext.ExtensionData)
  73. if err != nil {
  74. return nil, err
  75. }
  76. found[dst.Type()] = true
  77. }
  78. }
  79. }
  80. return found, nil
  81. }
  82. func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
  83. for _, ext := range el {
  84. if ext.ExtensionType == dst.Type() {
  85. err := safeUnmarshal(dst, ext.ExtensionData)
  86. if err != nil {
  87. return true, err
  88. }
  89. return true, nil
  90. }
  91. }
  92. return false, nil
  93. }
// ServerNameExtension is the server_name extension body, reduced to a single
// host name.
//
// struct {
//     NameType name_type;
//     select (name_type) {
//         case host_name: HostName;
//     } name;
// } ServerName;
//
// enum {
//     host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
//     ServerName server_name_list<1..2^16-1>
// } ServerNameList;
//
// But we only care about the case where there's a single DNS hostname. We
// never create anything else. When decoding, only the first entry is used: it
// is rejected unless its type is host_name, and later entries are ignored.
//
//      2         1          2
// | listLen | NameType | nameLen | name |
type ServerNameExtension string

// serverNameInner mirrors one ServerName entry: a name type followed by the
// host name bytes.
type serverNameInner struct {
	NameType uint8
	HostName []byte `tls:"head=2,min=1"`
}

// serverNameListInner mirrors ServerNameList, the list of ServerName entries.
type serverNameListInner struct {
	ServerNameList []serverNameInner `tls:"head=2,min=1"`
}

// Type returns ExtensionTypeServerName.
func (sni ServerNameExtension) Type() ExtensionType {
	return ExtensionTypeServerName
}
  127. func (sni ServerNameExtension) Marshal() ([]byte, error) {
  128. list := serverNameListInner{
  129. ServerNameList: []serverNameInner{{
  130. NameType: 0x00, // host_name
  131. HostName: []byte(sni),
  132. }},
  133. }
  134. return syntax.Marshal(list)
  135. }
  136. func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
  137. var list serverNameListInner
  138. read, err := syntax.Unmarshal(data, &list)
  139. if err != nil {
  140. return 0, err
  141. }
  142. // Syntax requires at least one entry
  143. // Entries beyond the first are ignored
  144. if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
  145. return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
  146. }
  147. *sni = ServerNameExtension(list.ServerNameList[0].HostName)
  148. return read, nil
  149. }
// KeyShareEntry is one key share: a named group and the key exchange data
// for that group.
//
// struct {
//     NamedGroup group;
//     opaque key_exchange<1..2^16-1>;
// } KeyShareEntry;
//
// struct {
//     select (Handshake.msg_type) {
//         case client_hello:
//             KeyShareEntry client_shares<0..2^16-1>;
//
//         case hello_retry_request:
//             NamedGroup selected_group;
//
//         case server_hello:
//             KeyShareEntry server_share;
//     };
// } KeyShare;
type KeyShareEntry struct {
	Group       NamedGroup
	KeyExchange []byte `tls:"head=2,min=1"`
}

// SizeValid reports whether the length of KeyExchange equals the size that
// keyExchangeSizeFromNamedGroup gives for Group.
func (kse KeyShareEntry) SizeValid() bool {
	return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
}
// KeyShareExtension is the key_share extension body. Its encoding depends on
// the handshake message that carries it, so HandshakeType must be set before
// calling Marshal or Unmarshal; both switch on it.
type KeyShareExtension struct {
	HandshakeType HandshakeType  // selects which of the encodings below is used
	SelectedGroup NamedGroup     // used for HelloRetryRequest
	Shares        []KeyShareEntry // used for ClientHello (any number) and ServerHello (exactly one)
}

// KeyShareClientHelloInner is the ClientHello form: a list of key shares.
type KeyShareClientHelloInner struct {
	ClientShares []KeyShareEntry `tls:"head=2,min=0"`
}

// KeyShareHelloRetryInner is the HelloRetryRequest form: the selected group.
type KeyShareHelloRetryInner struct {
	SelectedGroup NamedGroup
}

// KeyShareServerHelloInner is the ServerHello form: a single key share.
type KeyShareServerHelloInner struct {
	ServerShare KeyShareEntry
}

// Type returns ExtensionTypeKeyShare.
func (ks KeyShareExtension) Type() ExtensionType {
	return ExtensionTypeKeyShare
}
  191. func (ks KeyShareExtension) Marshal() ([]byte, error) {
  192. switch ks.HandshakeType {
  193. case HandshakeTypeClientHello:
  194. for _, share := range ks.Shares {
  195. if !share.SizeValid() {
  196. return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
  197. }
  198. }
  199. return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
  200. case HandshakeTypeHelloRetryRequest:
  201. if len(ks.Shares) > 0 {
  202. return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
  203. }
  204. return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
  205. case HandshakeTypeServerHello:
  206. if len(ks.Shares) != 1 {
  207. return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
  208. }
  209. if !ks.Shares[0].SizeValid() {
  210. return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
  211. }
  212. return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
  213. default:
  214. return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
  215. }
  216. }
  217. func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
  218. switch ks.HandshakeType {
  219. case HandshakeTypeClientHello:
  220. var inner KeyShareClientHelloInner
  221. read, err := syntax.Unmarshal(data, &inner)
  222. if err != nil {
  223. return 0, err
  224. }
  225. for _, share := range inner.ClientShares {
  226. if !share.SizeValid() {
  227. return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
  228. }
  229. }
  230. ks.Shares = inner.ClientShares
  231. return read, nil
  232. case HandshakeTypeHelloRetryRequest:
  233. var inner KeyShareHelloRetryInner
  234. read, err := syntax.Unmarshal(data, &inner)
  235. if err != nil {
  236. return 0, err
  237. }
  238. ks.SelectedGroup = inner.SelectedGroup
  239. return read, nil
  240. case HandshakeTypeServerHello:
  241. var inner KeyShareServerHelloInner
  242. read, err := syntax.Unmarshal(data, &inner)
  243. if err != nil {
  244. return 0, err
  245. }
  246. if !inner.ServerShare.SizeValid() {
  247. return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
  248. }
  249. ks.Shares = []KeyShareEntry{inner.ServerShare}
  250. return read, nil
  251. default:
  252. return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
  253. }
  254. }
// SupportedGroupsExtension is the supported_groups extension body: a list of
// named groups.
//
// struct {
//     NamedGroup named_group_list<2..2^16-1>;
// } NamedGroupList;
type SupportedGroupsExtension struct {
	Groups []NamedGroup `tls:"head=2,min=2"`
}

// Type returns ExtensionTypeSupportedGroups.
func (sg SupportedGroupsExtension) Type() ExtensionType {
	return ExtensionTypeSupportedGroups
}

// Marshal encodes sg by handing it to syntax.Marshal.
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
	return syntax.Marshal(sg)
}

// Unmarshal decodes data into sg via syntax.Unmarshal and returns that call's
// results unchanged.
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, sg)
}
// SignatureAlgorithmsExtension is the signature_algorithms extension body: a
// list of signature schemes.
//
// struct {
//     SignatureScheme supported_signature_algorithms<2..2^16-2>;
// } SignatureSchemeList
type SignatureAlgorithmsExtension struct {
	Algorithms []SignatureScheme `tls:"head=2,min=2"`
}

// Type returns ExtensionTypeSignatureAlgorithms.
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
	return ExtensionTypeSignatureAlgorithms
}

// Marshal encodes sa by handing it to syntax.Marshal.
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
	return syntax.Marshal(sa)
}

// Unmarshal decodes data into sa via syntax.Unmarshal and returns that call's
// results unchanged.
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, sa)
}
// PSKIdentity is one offered pre-shared key identity together with its
// obfuscated ticket age.
//
// struct {
//     opaque identity<1..2^16-1>;
//     uint32 obfuscated_ticket_age;
// } PskIdentity;
//
// opaque PskBinderEntry<32..255>;
//
// struct {
//     select (Handshake.msg_type) {
//         case client_hello:
//             PskIdentity identities<7..2^16-1>;
//             PskBinderEntry binders<33..2^16-1>;
//
//         case server_hello:
//             uint16 selected_identity;
//     };
//
// } PreSharedKeyExtension;
type PSKIdentity struct {
	Identity            []byte `tls:"head=2,min=1"`
	ObfuscatedTicketAge uint32
}

// PSKBinderEntry holds one binder value.
type PSKBinderEntry struct {
	Binder []byte `tls:"head=1,min=32"`
}

// PreSharedKeyExtension is the pre_shared_key extension body. Its encoding
// depends on the handshake message that carries it, so HandshakeType must be
// set before calling Marshal or Unmarshal; both switch on it.
type PreSharedKeyExtension struct {
	HandshakeType    HandshakeType    // selects the ClientHello or ServerHello form
	Identities       []PSKIdentity    // ClientHello only
	Binders          []PSKBinderEntry // ClientHello only; HasIdentity pairs Binders[i] with Identities[i]
	SelectedIdentity uint16           // ServerHello only
}

// preSharedKeyClientInner is the ClientHello form: identities and binders.
type preSharedKeyClientInner struct {
	Identities []PSKIdentity    `tls:"head=2,min=7"`
	Binders    []PSKBinderEntry `tls:"head=2,min=33"`
}

// preSharedKeyServerInner is the ServerHello form: the selected identity.
type preSharedKeyServerInner struct {
	SelectedIdentity uint16
}

// Type returns ExtensionTypePreSharedKey.
func (psk PreSharedKeyExtension) Type() ExtensionType {
	return ExtensionTypePreSharedKey
}
  326. func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
  327. switch psk.HandshakeType {
  328. case HandshakeTypeClientHello:
  329. return syntax.Marshal(preSharedKeyClientInner{
  330. Identities: psk.Identities,
  331. Binders: psk.Binders,
  332. })
  333. case HandshakeTypeServerHello:
  334. if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
  335. return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
  336. }
  337. return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
  338. default:
  339. return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
  340. }
  341. }
  342. func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
  343. switch psk.HandshakeType {
  344. case HandshakeTypeClientHello:
  345. var inner preSharedKeyClientInner
  346. read, err := syntax.Unmarshal(data, &inner)
  347. if err != nil {
  348. return 0, err
  349. }
  350. if len(inner.Identities) != len(inner.Binders) {
  351. return 0, fmt.Errorf("Lengths of identities and binders not equal")
  352. }
  353. psk.Identities = inner.Identities
  354. psk.Binders = inner.Binders
  355. return read, nil
  356. case HandshakeTypeServerHello:
  357. var inner preSharedKeyServerInner
  358. read, err := syntax.Unmarshal(data, &inner)
  359. if err != nil {
  360. return 0, err
  361. }
  362. psk.SelectedIdentity = inner.SelectedIdentity
  363. return read, nil
  364. default:
  365. return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
  366. }
  367. }
  368. func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
  369. for i, localID := range psk.Identities {
  370. if bytes.Equal(localID.Identity, id) {
  371. return psk.Binders[i].Binder, true
  372. }
  373. }
  374. return nil, false
  375. }
// PSKKeyExchangeModesExtension is the psk_key_exchange_modes extension body:
// a list of key exchange modes.
//
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
//
// struct {
//     PskKeyExchangeMode ke_modes<1..255>;
// } PskKeyExchangeModes;
type PSKKeyExchangeModesExtension struct {
	KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
}

// Type returns ExtensionTypePSKKeyExchangeModes.
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
	return ExtensionTypePSKKeyExchangeModes
}

// Marshal encodes pkem by handing it to syntax.Marshal.
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
	return syntax.Marshal(pkem)
}

// Unmarshal decodes data into pkem via syntax.Unmarshal and returns that
// call's results unchanged.
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, pkem)
}
// EarlyDataExtension is the early_data extension body, which carries no
// fields.
//
// struct {
// } EarlyDataIndication;
type EarlyDataExtension struct{}

// Type returns ExtensionTypeEarlyData.
func (ed EarlyDataExtension) Type() ExtensionType {
	return ExtensionTypeEarlyData
}

// Marshal returns an empty, non-nil byte slice, since the body has no fields.
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
	return []byte{}, nil
}

// Unmarshal ignores data and always reports zero bytes read with no error.
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
	return 0, nil
}
// TicketEarlyDataInfoExtension is the ticket early data info extension body,
// carrying the maximum early data size.
//
// struct {
//     uint32 max_early_data_size;
// } TicketEarlyDataInfo;
type TicketEarlyDataInfoExtension struct {
	MaxEarlyDataSize uint32
}

// Type returns ExtensionTypeTicketEarlyDataInfo.
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
	return ExtensionTypeTicketEarlyDataInfo
}

// Marshal encodes tedi by handing it to syntax.Marshal.
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
	return syntax.Marshal(tedi)
}

// Unmarshal decodes data into tedi via syntax.Unmarshal and returns that
// call's results unchanged.
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, tedi)
}
// ALPNExtension is the ALPN extension body, exposing the protocol names as
// Go strings.
//
// opaque ProtocolName<1..2^8-1>;
//
// struct {
//     ProtocolName protocol_name_list<2..2^16-1>
// } ProtocolNameList;
type ALPNExtension struct {
	Protocols []string
}

// protocolNameInner mirrors one ProtocolName entry as raw bytes.
type protocolNameInner struct {
	Name []byte `tls:"head=1,min=1"`
}

// alpnExtensionInner mirrors ProtocolNameList, the list of protocol names.
type alpnExtensionInner struct {
	Protocols []protocolNameInner `tls:"head=2,min=2"`
}

// Type returns ExtensionTypeALPN.
func (alpn ALPNExtension) Type() ExtensionType {
	return ExtensionTypeALPN
}
  437. func (alpn ALPNExtension) Marshal() ([]byte, error) {
  438. protocols := make([]protocolNameInner, len(alpn.Protocols))
  439. for i, protocol := range alpn.Protocols {
  440. protocols[i] = protocolNameInner{[]byte(protocol)}
  441. }
  442. return syntax.Marshal(alpnExtensionInner{protocols})
  443. }
  444. func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
  445. var inner alpnExtensionInner
  446. read, err := syntax.Unmarshal(data, &inner)
  447. if err != nil {
  448. return 0, err
  449. }
  450. alpn.Protocols = make([]string, len(inner.Protocols))
  451. for i, protocol := range inner.Protocols {
  452. alpn.Protocols[i] = string(protocol.Name)
  453. }
  454. return read, nil
  455. }
// SupportedVersionsExtension is the supported_versions extension body. Its
// encoding depends on the handshake message that carries it, so HandshakeType
// must be set before calling Marshal or Unmarshal; both switch on it.
//
// struct {
//     ProtocolVersion versions<2..254>;
// } SupportedVersions;
type SupportedVersionsExtension struct {
	HandshakeType HandshakeType
	// Versions is the full list for ClientHello. For ServerHello and
	// HelloRetryRequest only Versions[0] is encoded, and decoding yields a
	// one-element list.
	Versions []uint16
}

// SupportedVersionsClientHelloInner is the ClientHello form: a version list.
type SupportedVersionsClientHelloInner struct {
	Versions []uint16 `tls:"head=1,min=2,max=254"`
}

// SupportedVersionsServerHelloInner is the ServerHello / HelloRetryRequest
// form: a single version.
type SupportedVersionsServerHelloInner struct {
	Version uint16
}

// Type returns ExtensionTypeSupportedVersions.
func (sv SupportedVersionsExtension) Type() ExtensionType {
	return ExtensionTypeSupportedVersions
}
  472. func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
  473. switch sv.HandshakeType {
  474. case HandshakeTypeClientHello:
  475. return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
  476. case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
  477. return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
  478. default:
  479. return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
  480. }
  481. }
  482. func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
  483. switch sv.HandshakeType {
  484. case HandshakeTypeClientHello:
  485. var inner SupportedVersionsClientHelloInner
  486. read, err := syntax.Unmarshal(data, &inner)
  487. if err != nil {
  488. return 0, err
  489. }
  490. sv.Versions = inner.Versions
  491. return read, nil
  492. case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
  493. var inner SupportedVersionsServerHelloInner
  494. read, err := syntax.Unmarshal(data, &inner)
  495. if err != nil {
  496. return 0, err
  497. }
  498. sv.Versions = []uint16{inner.Version}
  499. return read, nil
  500. default:
  501. return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
  502. }
  503. }
// CookieExtension is the cookie extension body: an opaque cookie value.
//
// struct {
//     opaque cookie<1..2^16-1>;
// } Cookie;
type CookieExtension struct {
	Cookie []byte `tls:"head=2,min=1"`
}

// Type returns ExtensionTypeCookie.
func (c CookieExtension) Type() ExtensionType {
	return ExtensionTypeCookie
}

// Marshal encodes c by handing it to syntax.Marshal.
func (c CookieExtension) Marshal() ([]byte, error) {
	return syntax.Marshal(c)
}

// Unmarshal decodes data into c via syntax.Unmarshal and returns that call's
// results unchanged.
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
	return syntax.Unmarshal(data, c)
}