nftables_runner.go 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux
  4. package linuxfw
  5. import (
  6. "encoding/binary"
  7. "encoding/hex"
  8. "errors"
  9. "fmt"
  10. "net"
  11. "net/netip"
  12. "reflect"
  13. "strings"
  14. "github.com/google/nftables"
  15. "github.com/google/nftables/expr"
  16. "golang.org/x/sys/unix"
  17. "tailscale.com/net/tsaddr"
  18. "tailscale.com/types/logger"
  19. "tailscale.com/types/ptr"
  20. )
  21. const (
  22. chainNameForward = "ts-forward"
  23. chainNameInput = "ts-input"
  24. chainNamePostrouting = "ts-postrouting"
  25. )
  26. // chainTypeRegular is an nftables chain that does not apply to a hook.
  27. const chainTypeRegular = ""
  28. type chainInfo struct {
  29. table *nftables.Table
  30. name string
  31. chainType nftables.ChainType
  32. chainHook *nftables.ChainHook
  33. chainPriority *nftables.ChainPriority
  34. chainPolicy *nftables.ChainPolicy
  35. }
  36. type nftable struct {
  37. Proto nftables.TableFamily
  38. Filter *nftables.Table
  39. Nat *nftables.Table
  40. }
  41. // nftablesRunner implements a netfilterRunner using the netlink based nftables
  42. // library. As nftables allows for arbitrary tables and chains, there is a need
  43. // to follow conventions in order to integrate well with a surrounding
  44. // ecosystem. The rules installed by nftablesRunner have the following
  45. // properties:
  46. // - Install rules that intend to take precedence over rules installed by
  47. // other software. Tailscale provides packet filtering for tailnet traffic
  48. // inside the daemon based on the tailnet ACL rules.
  49. // - As nftables "accept" is not final, rules from high priority tables (low
  50. // numbers) will fall through to lower priority tables (high numbers). In
  51. // order to effectively be 'final', we install "jump" rules into conventional
  52. // tables and chains that will reach an accept verdict inside those tables.
  53. // - The table and chain conventions followed here are those used by
  54. // `iptables-nft` and `ufw`, so that those tools co-exist and do not
  55. // negatively affect Tailscale function.
  56. type nftablesRunner struct {
  57. conn *nftables.Conn
  58. nft4 *nftable
  59. nft6 *nftable
  60. v6Available bool
  61. v6NATAvailable bool
  62. }
  63. func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
  64. polAccept := nftables.ChainPolicyAccept
  65. table := n.getNFTByAddr(dst)
  66. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  67. if err != nil {
  68. return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
  69. }
  70. // ensure prerouting chain exists
  71. preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
  72. table: nat,
  73. name: "PREROUTING",
  74. chainType: nftables.ChainTypeNAT,
  75. chainHook: nftables.ChainHookPrerouting,
  76. chainPriority: nftables.ChainPriorityNATDest,
  77. chainPolicy: &polAccept,
  78. })
  79. if err != nil {
  80. return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
  81. }
  82. return nat, preroutingCh, nil
  83. }
  84. func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
  85. nat, preroutingCh, err := n.ensurePreroutingChain(dst)
  86. if err != nil {
  87. return err
  88. }
  89. var daddrOffset, fam, dadderLen uint32
  90. if origDst.Is4() {
  91. daddrOffset = 16
  92. dadderLen = 4
  93. fam = unix.NFPROTO_IPV4
  94. } else {
  95. daddrOffset = 24
  96. dadderLen = 16
  97. fam = unix.NFPROTO_IPV6
  98. }
  99. dnatRule := &nftables.Rule{
  100. Table: nat,
  101. Chain: preroutingCh,
  102. Exprs: []expr.Any{
  103. &expr.Payload{
  104. DestRegister: 1,
  105. Base: expr.PayloadBaseNetworkHeader,
  106. Offset: daddrOffset,
  107. Len: dadderLen,
  108. },
  109. &expr.Cmp{
  110. Op: expr.CmpOpEq,
  111. Register: 1,
  112. Data: origDst.AsSlice(),
  113. },
  114. &expr.Immediate{
  115. Register: 1,
  116. Data: dst.AsSlice(),
  117. },
  118. &expr.NAT{
  119. Type: expr.NATTypeDestNAT,
  120. Family: fam,
  121. RegAddrMin: 1,
  122. },
  123. },
  124. }
  125. n.conn.InsertRule(dnatRule)
  126. return n.conn.Flush()
  127. }
  128. func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
  129. nat, preroutingCh, err := n.ensurePreroutingChain(dst)
  130. if err != nil {
  131. return err
  132. }
  133. var famConst uint32
  134. if dst.Is4() {
  135. famConst = unix.NFPROTO_IPV4
  136. } else {
  137. famConst = unix.NFPROTO_IPV6
  138. }
  139. dnatRule := &nftables.Rule{
  140. Table: nat,
  141. Chain: preroutingCh,
  142. Exprs: []expr.Any{
  143. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  144. &expr.Cmp{
  145. Op: expr.CmpOpNeq,
  146. Register: 1,
  147. Data: []byte(tunname),
  148. },
  149. &expr.Immediate{
  150. Register: 1,
  151. Data: dst.AsSlice(),
  152. },
  153. &expr.NAT{
  154. Type: expr.NATTypeDestNAT,
  155. Family: famConst,
  156. RegAddrMin: 1,
  157. },
  158. },
  159. }
  160. n.conn.AddRule(dnatRule)
  161. return n.conn.Flush()
  162. }
  163. func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
  164. polAccept := nftables.ChainPolicyAccept
  165. table := n.getNFTByAddr(dst)
  166. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  167. if err != nil {
  168. return fmt.Errorf("error ensuring nat table exists: %w", err)
  169. }
  170. // ensure postrouting chain exists
  171. postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
  172. table: nat,
  173. name: "POSTROUTING",
  174. chainType: nftables.ChainTypeNAT,
  175. chainHook: nftables.ChainHookPostrouting,
  176. chainPriority: nftables.ChainPriorityNATSource,
  177. chainPolicy: &polAccept,
  178. })
  179. if err != nil {
  180. return fmt.Errorf("error ensuring postrouting chain: %w", err)
  181. }
  182. var daddrOffset, fam, daddrLen uint32
  183. if dst.Is4() {
  184. daddrOffset = 16
  185. daddrLen = 4
  186. fam = unix.NFPROTO_IPV4
  187. } else {
  188. daddrOffset = 24
  189. daddrLen = 16
  190. fam = unix.NFPROTO_IPV6
  191. }
  192. snatRule := &nftables.Rule{
  193. Table: nat,
  194. Chain: postRoutingCh,
  195. Exprs: []expr.Any{
  196. &expr.Payload{
  197. DestRegister: 1,
  198. Base: expr.PayloadBaseNetworkHeader,
  199. Offset: daddrOffset,
  200. Len: daddrLen,
  201. },
  202. &expr.Cmp{
  203. Op: expr.CmpOpEq,
  204. Register: 1,
  205. Data: dst.AsSlice(),
  206. },
  207. &expr.Immediate{
  208. Register: 1,
  209. Data: src.AsSlice(),
  210. },
  211. &expr.NAT{
  212. Type: expr.NATTypeSourceNAT,
  213. Family: fam,
  214. RegAddrMin: 1,
  215. },
  216. },
  217. }
  218. n.conn.AddRule(snatRule)
  219. return n.conn.Flush()
  220. }
  221. func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
  222. polAccept := nftables.ChainPolicyAccept
  223. table := n.getNFTByAddr(addr)
  224. filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
  225. if err != nil {
  226. return fmt.Errorf("error ensuring filter table: %w", err)
  227. }
  228. // ensure forwarding chain exists
  229. fwChain, err := getOrCreateChain(n.conn, chainInfo{
  230. table: filterTable,
  231. name: "FORWARD",
  232. chainType: nftables.ChainTypeFilter,
  233. chainHook: nftables.ChainHookForward,
  234. chainPriority: nftables.ChainPriorityFilter,
  235. chainPolicy: &polAccept,
  236. })
  237. if err != nil {
  238. return fmt.Errorf("error ensuring forward chain: %w", err)
  239. }
  240. clampRule := &nftables.Rule{
  241. Table: filterTable,
  242. Chain: fwChain,
  243. Exprs: []expr.Any{
  244. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  245. &expr.Cmp{
  246. Op: expr.CmpOpEq,
  247. Register: 1,
  248. Data: []byte(tun),
  249. },
  250. &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
  251. &expr.Cmp{
  252. Op: expr.CmpOpEq,
  253. Register: 1,
  254. Data: []byte{unix.IPPROTO_TCP},
  255. },
  256. &expr.Payload{
  257. DestRegister: 1,
  258. Base: expr.PayloadBaseTransportHeader,
  259. Offset: 13,
  260. Len: 1,
  261. },
  262. &expr.Bitwise{
  263. DestRegister: 1,
  264. SourceRegister: 1,
  265. Len: 1,
  266. Mask: []byte{0x02},
  267. Xor: []byte{0x00},
  268. },
  269. &expr.Cmp{
  270. Op: expr.CmpOpNeq,
  271. Register: 1,
  272. Data: []byte{0x00},
  273. },
  274. &expr.Rt{
  275. Register: 1,
  276. Key: expr.RtTCPMSS,
  277. },
  278. &expr.Byteorder{
  279. DestRegister: 1,
  280. SourceRegister: 1,
  281. Op: expr.ByteorderHton,
  282. Len: 2,
  283. Size: 2,
  284. },
  285. &expr.Exthdr{
  286. SourceRegister: 1,
  287. Type: 2,
  288. Offset: 2,
  289. Len: 2,
  290. Op: expr.ExthdrOpTcpopt,
  291. },
  292. },
  293. }
  294. n.conn.AddRule(clampRule)
  295. return n.conn.Flush()
  296. }
  297. // deleteTableIfExists deletes a nftables table via connection c if it exists
  298. // within the given family.
  299. func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
  300. t, err := getTableIfExists(c, family, name)
  301. if err != nil {
  302. return fmt.Errorf("get table: %w", err)
  303. }
  304. if t == nil {
  305. // Table does not exist, so nothing to delete.
  306. return nil
  307. }
  308. c.DelTable(t)
  309. if err := c.Flush(); err != nil {
  310. if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
  311. // Check if the table still exists. If it does not, then the error
  312. // is due to the table not existing, so we can ignore it. Maybe a
  313. // concurrent process deleted the table.
  314. return nil
  315. }
  316. return fmt.Errorf("del table: %w", err)
  317. }
  318. return nil
  319. }
  320. // getTableIfExists returns the table with the given name from the given family
  321. // if it exists. If none match, it returns (nil, nil).
  322. func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
  323. tables, err := c.ListTables()
  324. if err != nil {
  325. return nil, fmt.Errorf("get tables: %w", err)
  326. }
  327. for _, table := range tables {
  328. if table.Name == name && table.Family == family {
  329. return table, nil
  330. }
  331. }
  332. return nil, nil
  333. }
  334. // createTableIfNotExist creates a nftables table via connection c if it does
  335. // not exist within the given family.
  336. func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
  337. if t, err := getTableIfExists(c, family, name); err != nil {
  338. return nil, fmt.Errorf("get table: %w", err)
  339. } else if t != nil {
  340. return t, nil
  341. }
  342. t := c.AddTable(&nftables.Table{
  343. Family: family,
  344. Name: name,
  345. })
  346. if err := c.Flush(); err != nil {
  347. return nil, fmt.Errorf("add table: %w", err)
  348. }
  349. return t, nil
  350. }
  351. type errorChainNotFound struct {
  352. chainName string
  353. tableName string
  354. }
  355. func (e errorChainNotFound) Error() string {
  356. return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
  357. }
  358. // getChainFromTable returns the chain with the given name from the given table.
  359. // Note that a chain name is unique within a table.
  360. func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
  361. chains, err := c.ListChainsOfTableFamily(table.Family)
  362. if err != nil {
  363. return nil, fmt.Errorf("list chains: %w", err)
  364. }
  365. for _, chain := range chains {
  366. // Table family is already checked so table name is unique
  367. if chain.Table.Name == table.Name && chain.Name == name {
  368. return chain, nil
  369. }
  370. }
  371. return nil, errorChainNotFound{table.Name, name}
  372. }
  373. // isTSChain reports whether `name` begins with "ts-" (and is thus a
  374. // Tailscale-managed chain).
  375. func isTSChain(name string) bool {
  376. return strings.HasPrefix(name, "ts-")
  377. }
  378. // createChainIfNotExist creates a chain with the given name in the given table
  379. // if it does not exist.
  380. func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
  381. _, err := getOrCreateChain(c, cinfo)
  382. return err
  383. }
  384. func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) {
  385. chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
  386. if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
  387. return nil, fmt.Errorf("get chain: %w", err)
  388. } else if err == nil {
  389. // The chain already exists. If it is a TS chain, check the
  390. // type/hook/priority, but for "conventional chains" assume they're what
  391. // we expect (in case iptables-nft/ufw make minor behavior changes in
  392. // the future).
  393. if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
  394. return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
  395. }
  396. return chain, nil
  397. }
  398. chain = c.AddChain(&nftables.Chain{
  399. Name: cinfo.name,
  400. Table: cinfo.table,
  401. Type: cinfo.chainType,
  402. Hooknum: cinfo.chainHook,
  403. Priority: cinfo.chainPriority,
  404. Policy: cinfo.chainPolicy,
  405. })
  406. if err := c.Flush(); err != nil {
  407. return nil, fmt.Errorf("add chain: %w", err)
  408. }
  409. return chain, nil
  410. }
  411. // NetfilterRunner abstracts helpers to run netfilter commands. It is
  412. // implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
  413. type NetfilterRunner interface {
  414. // AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
  415. // is added only if it does not already exist.
  416. AddLoopbackRule(addr netip.Addr) error
  417. // DelLoopbackRule removes the rule added by AddLoopbackRule.
  418. DelLoopbackRule(addr netip.Addr) error
  419. // AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
  420. // "POSTROUTING" to jump from those chains to tailscale chains.
  421. AddHooks() error
  422. // DelHooks deletes rules added by AddHooks.
  423. DelHooks(logf logger.Logf) error
  424. // AddChains creates custom Tailscale chains.
  425. AddChains() error
  426. // DelChains removes chains added by AddChains.
  427. DelChains() error
  428. // AddBase adds rules reused by different other rules.
  429. AddBase(tunname string) error
  430. // DelBase removes rules added by AddBase.
  431. DelBase() error
  432. // AddSNATRule adds the netfilter rule to SNAT incoming traffic over
  433. // the Tailscale interface destined for local subnets. An error is
  434. // returned if the rule already exists.
  435. AddSNATRule() error
  436. // DelSNATRule removes the rule added by AddSNATRule.
  437. DelSNATRule() error
  438. // HasIPV6 reports true if the system supports IPv6.
  439. HasIPV6() bool
  440. // HasIPV6NAT reports true if the system supports IPv6 NAT.
  441. HasIPV6NAT() bool
  442. // AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
  443. // destined for the given original destination to the given new destination.
  444. // This is used to forward all traffic destined for the Tailscale interface
  445. // to the provided destination, as used in the Kubernetes ingress proxies.
  446. AddDNATRule(origDst, dst netip.Addr) error
  447. // AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
  448. // traffic destined for dst to src.
  449. // This is used to forward traffic destined for the local machine over
  450. // the Tailscale interface, as used in the Kubernetes egress proxies.
  451. AddSNATRuleForDst(src, dst netip.Addr) error
  452. // DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
  453. // all traffic inbound from any interface except exemptInterface to dst.
  454. // This is used to forward traffic destined for the local machine over
  455. // the Tailscale interface, as used in the Kubernetes egress proxies.//
  456. DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error
  457. // ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for
  458. // traffic destined for the provided tun interface.
  459. ClampMSSToPMTU(tun string, addr netip.Addr) error
  460. // AddMagicsockPortRule adds a rule to the ts-input chain to accept
  461. // incoming traffic on the specified port, to allow magicsock to
  462. // communicate.
  463. AddMagicsockPortRule(port uint16, network string) error
  464. // DelMagicsockPortRule removes the rule created by AddMagicsockPortRule,
  465. // if it exists.
  466. DelMagicsockPortRule(port uint16, network string) error
  467. }
  468. // New creates a NetfilterRunner, auto-detecting whether to use
  469. // nftables or iptables.
  470. // As nftables is still experimental, iptables will be used unless
  471. // either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint
  472. // parameter, is set to one of "nftables" or "auto".
  473. func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) {
  474. mode := detectFirewallMode(logf, prefHint)
  475. switch mode {
  476. case FirewallModeIPTables:
  477. return newIPTablesRunner(logf)
  478. case FirewallModeNfTables:
  479. return newNfTablesRunner(logf)
  480. default:
  481. return nil, fmt.Errorf("unknown firewall mode %v", mode)
  482. }
  483. }
  484. // newNfTablesRunner creates a new nftablesRunner without guaranteeing
  485. // the existence of the tables and chains.
  486. func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
  487. conn, err := nftables.New()
  488. if err != nil {
  489. return nil, fmt.Errorf("nftables connection: %w", err)
  490. }
  491. nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
  492. v6err := checkIPv6(logf)
  493. if v6err != nil {
  494. logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
  495. }
  496. supportsV6 := v6err == nil
  497. supportsV6NAT := supportsV6 && checkSupportsV6NAT()
  498. var nft6 *nftable
  499. if supportsV6 {
  500. logf("v6nat availability: %v", supportsV6NAT)
  501. nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
  502. }
  503. // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
  504. return &nftablesRunner{
  505. conn: conn,
  506. nft4: nft4,
  507. nft6: nft6,
  508. v6Available: supportsV6,
  509. v6NATAvailable: supportsV6NAT,
  510. }, nil
  511. }
  512. // newLoadSaddrExpr creates a new nftables expression that loads the source
  513. // address of the packet into the given register.
  514. func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
  515. switch proto {
  516. case nftables.TableFamilyIPv4:
  517. return &expr.Payload{
  518. DestRegister: destReg,
  519. Base: expr.PayloadBaseNetworkHeader,
  520. Offset: 12,
  521. Len: 4,
  522. }, nil
  523. case nftables.TableFamilyIPv6:
  524. return &expr.Payload{
  525. DestRegister: destReg,
  526. Base: expr.PayloadBaseNetworkHeader,
  527. Offset: 8,
  528. Len: 16,
  529. }, nil
  530. default:
  531. return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
  532. }
  533. }
  534. // newLoadDportExpr creates a new nftables express that loads the desination port
  535. // of a TCP/UDP packet into the given register.
  536. func newLoadDportExpr(destReg uint32) expr.Any {
  537. return &expr.Payload{
  538. DestRegister: destReg,
  539. Base: expr.PayloadBaseTransportHeader,
  540. Offset: 2,
  541. Len: 2,
  542. }
  543. }
  544. // HasIPV6 reports true if the system supports IPv6.
  545. func (n *nftablesRunner) HasIPV6() bool {
  546. return n.v6Available
  547. }
  548. // HasIPV6NAT returns true if the system supports IPv6 NAT.
  549. func (n *nftablesRunner) HasIPV6NAT() bool {
  550. return n.v6NATAvailable
  551. }
  552. // findRule iterates through the rules to find the rule with matching expressions.
  553. func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
  554. rules, err := conn.GetRules(rule.Table, rule.Chain)
  555. if err != nil {
  556. return nil, fmt.Errorf("get nftables rules: %w", err)
  557. }
  558. if len(rules) == 0 {
  559. return nil, nil
  560. }
  561. ruleLoop:
  562. for _, r := range rules {
  563. if len(r.Exprs) != len(rule.Exprs) {
  564. continue
  565. }
  566. for i, e := range r.Exprs {
  567. // Skip counter expressions, as they will not match.
  568. if _, ok := e.(*expr.Counter); ok {
  569. continue
  570. }
  571. if !reflect.DeepEqual(e, rule.Exprs[i]) {
  572. continue ruleLoop
  573. }
  574. }
  575. return r, nil
  576. }
  577. return nil, nil
  578. }
  579. func createLoopbackRule(
  580. proto nftables.TableFamily,
  581. table *nftables.Table,
  582. chain *nftables.Chain,
  583. addr netip.Addr,
  584. ) (*nftables.Rule, error) {
  585. saddrExpr, err := newLoadSaddrExpr(proto, 1)
  586. if err != nil {
  587. return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
  588. }
  589. loopBackRule := &nftables.Rule{
  590. Table: table,
  591. Chain: chain,
  592. Exprs: []expr.Any{
  593. &expr.Meta{
  594. Key: expr.MetaKeyIIFNAME,
  595. Register: 1,
  596. },
  597. &expr.Cmp{
  598. Op: expr.CmpOpEq,
  599. Register: 1,
  600. Data: []byte("lo"),
  601. },
  602. saddrExpr,
  603. &expr.Cmp{
  604. Op: expr.CmpOpEq,
  605. Register: 1,
  606. Data: addr.AsSlice(),
  607. },
  608. &expr.Counter{},
  609. &expr.Verdict{
  610. Kind: expr.VerdictAccept,
  611. },
  612. },
  613. }
  614. return loopBackRule, nil
  615. }
  616. // insertLoopbackRule inserts the TS loop back rule into
  617. // the given chain as the first rule if it does not exist.
  618. func insertLoopbackRule(
  619. conn *nftables.Conn, proto nftables.TableFamily,
  620. table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
  621. loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
  622. if err != nil {
  623. return fmt.Errorf("create loopback rule: %w", err)
  624. }
  625. // If TestDial is set, we are running in test mode and we should not
  626. // find rule because header will mismatch.
  627. if conn.TestDial == nil {
  628. // Check if the rule already exists.
  629. rule, err := findRule(conn, loopBackRule)
  630. if err != nil {
  631. return fmt.Errorf("find rule: %w", err)
  632. }
  633. if rule != nil {
  634. // Rule already exists, no need to insert.
  635. return nil
  636. }
  637. }
  638. // This inserts the rule to the top of the chain
  639. _ = conn.InsertRule(loopBackRule)
  640. if err = conn.Flush(); err != nil {
  641. return fmt.Errorf("insert rule: %w", err)
  642. }
  643. return nil
  644. }
  645. // getNFTByAddr returns the nftables with correct IP family
  646. // that we will be using for the given address.
  647. func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
  648. if addr.Is6() {
  649. return n.nft6
  650. }
  651. return n.nft4
  652. }
  653. // AddLoopbackRule adds an nftables rule to permit loopback traffic to
  654. // a local Tailscale IP. This rule is added only if it does not already exist.
  655. func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
  656. nf := n.getNFTByAddr(addr)
  657. inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
  658. if err != nil {
  659. return fmt.Errorf("get input chain: %w", err)
  660. }
  661. if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
  662. return fmt.Errorf("add loopback rule: %w", err)
  663. }
  664. return nil
  665. }
  666. // DelLoopbackRule removes the nftables rule permitting loopback
  667. // traffic to a Tailscale IP.
  668. func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
  669. nf := n.getNFTByAddr(addr)
  670. inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
  671. if err != nil {
  672. return fmt.Errorf("get input chain: %w", err)
  673. }
  674. loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
  675. if err != nil {
  676. return fmt.Errorf("create loopback rule: %w", err)
  677. }
  678. existingLoopBackRule, err := findRule(n.conn, loopBackRule)
  679. if err != nil {
  680. return fmt.Errorf("find loop back rule: %w", err)
  681. }
  682. if existingLoopBackRule == nil {
  683. // Rule does not exist, no need to delete.
  684. return nil
  685. }
  686. if err := n.conn.DelRule(existingLoopBackRule); err != nil {
  687. return fmt.Errorf("delete rule: %w", err)
  688. }
  689. return n.conn.Flush()
  690. }
  691. // getTables gets the available nftable in nftables runner.
  692. func (n *nftablesRunner) getTables() []*nftable {
  693. if n.v6Available {
  694. return []*nftable{n.nft4, n.nft6}
  695. }
  696. return []*nftable{n.nft4}
  697. }
  698. // getNATTables gets the available nftable in nftables runner.
  699. // If the system does not support IPv6 NAT, only the IPv4 nftable
  700. // will be returned.
  701. func (n *nftablesRunner) getNATTables() []*nftable {
  702. if n.v6NATAvailable {
  703. return n.getTables()
  704. }
  705. return []*nftable{n.nft4}
  706. }
  707. // AddChains creates custom Tailscale chains in netfilter via nftables
  708. // if the ts-chain doesn't already exist.
  709. func (n *nftablesRunner) AddChains() error {
  710. polAccept := nftables.ChainPolicyAccept
  711. for _, table := range n.getTables() {
  712. // Create the filter table if it doesn't exist, this table name is the same
  713. // as the name used by iptables-nft and ufw. We install rules into the
  714. // same conventional table so that `accept` verdicts from our jump
  715. // chains are conclusive.
  716. filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
  717. if err != nil {
  718. return fmt.Errorf("create table: %w", err)
  719. }
  720. table.Filter = filter
  721. // Adding the "conventional chains" that are used by iptables-nft and ufw.
  722. if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
  723. return fmt.Errorf("create forward chain: %w", err)
  724. }
  725. if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
  726. return fmt.Errorf("create input chain: %w", err)
  727. }
  728. // Adding the tailscale chains that contain our rules.
  729. if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
  730. return fmt.Errorf("create forward chain: %w", err)
  731. }
  732. if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
  733. return fmt.Errorf("create input chain: %w", err)
  734. }
  735. }
  736. for _, table := range n.getNATTables() {
  737. // Create the nat table if it doesn't exist, this table name is the same
  738. // as the name used by iptables-nft and ufw. We install rules into the
  739. // same conventional table so that `accept` verdicts from our jump
  740. // chains are conclusive.
  741. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  742. if err != nil {
  743. return fmt.Errorf("create table: %w", err)
  744. }
  745. table.Nat = nat
  746. // Adding the "conventional chains" that are used by iptables-nft and ufw.
  747. if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
  748. return fmt.Errorf("create postrouting chain: %w", err)
  749. }
  750. // Adding the tailscale chain that contains our rules.
  751. if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
  752. return fmt.Errorf("create postrouting chain: %w", err)
  753. }
  754. }
  755. return n.conn.Flush()
  756. }
  // Names of the throwaway table and chain used to probe whether nftables is
  // usable. They are created and then deleted again: if both steps succeed
  // nftables can be used, otherwise we assume the system does not support it.
  // See createDummyPostroutingChains.
  const (
  	// tsDummyTableName is the name of the probe NAT table.
  	tsDummyTableName = "ts-test-nat"
  	// tsDummyChainName is the name of the probe postrouting chain.
  	tsDummyChainName = "ts-test-postrouting"
  )
  766. // createDummyPostroutingChains creates dummy postrouting chains in netfilter
  767. // via netfilter via nftables, as a last resort measure to detect that nftables
  768. // can be used. It cleans up the dummy chains after creation.
  769. func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
  770. polAccept := ptr.To(nftables.ChainPolicyAccept)
  771. for _, table := range n.getNATTables() {
  772. nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
  773. if err != nil {
  774. return fmt.Errorf("create nat table: %w", err)
  775. }
  776. defer func(fm nftables.TableFamily) {
  777. if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil {
  778. retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
  779. }
  780. }(table.Proto)
  781. table.Nat = nat
  782. if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
  783. return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
  784. }
  785. if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
  786. return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
  787. }
  788. }
  789. return nil
  790. }
  791. // deleteChainIfExists deletes a chain if it exists.
  792. func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
  793. chain, err := getChainFromTable(c, table, name)
  794. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
  795. return fmt.Errorf("get chain: %w", err)
  796. } else if err != nil {
  797. // If the chain doesn't exist, we don't need to delete it.
  798. return nil
  799. }
  800. c.FlushChain(chain)
  801. c.DelChain(chain)
  802. if err := c.Flush(); err != nil {
  803. return fmt.Errorf("flush and delete chain: %w", err)
  804. }
  805. return nil
  806. }
  807. // DelChains removes the custom Tailscale chains from netfilter via nftables.
  808. func (n *nftablesRunner) DelChains() error {
  809. for _, table := range n.getTables() {
  810. if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
  811. return fmt.Errorf("delete chain: %w", err)
  812. }
  813. if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
  814. return fmt.Errorf("delete chain: %w", err)
  815. }
  816. }
  817. if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
  818. return fmt.Errorf("delete chain: %w", err)
  819. }
  820. if n.v6NATAvailable {
  821. if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
  822. return fmt.Errorf("delete chain: %w", err)
  823. }
  824. }
  825. if err := n.conn.Flush(); err != nil {
  826. return fmt.Errorf("flush: %w", err)
  827. }
  828. return nil
  829. }
  830. // createHookRule creates a rule to jump from a hooked chain to a regular chain.
  831. func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
  832. exprs := []expr.Any{
  833. &expr.Counter{},
  834. &expr.Verdict{
  835. Kind: expr.VerdictJump,
  836. Chain: toChainName,
  837. },
  838. }
  839. rule := &nftables.Rule{
  840. Table: table,
  841. Chain: fromChain,
  842. Exprs: exprs,
  843. }
  844. return rule
  845. }
  846. // addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
  847. func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
  848. rule := createHookRule(table, fromChain, toChainName)
  849. _ = conn.InsertRule(rule)
  850. if err := conn.Flush(); err != nil {
  851. return fmt.Errorf("flush add rule: %w", err)
  852. }
  853. return nil
  854. }
  855. // AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
  856. // in tables and jump from those chains to tailscale chains.
  857. func (n *nftablesRunner) AddHooks() error {
  858. conn := n.conn
  859. for _, table := range n.getTables() {
  860. inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
  861. if err != nil {
  862. return fmt.Errorf("get INPUT chain: %w", err)
  863. }
  864. err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
  865. if err != nil {
  866. return fmt.Errorf("Addhook: %w", err)
  867. }
  868. forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
  869. if err != nil {
  870. return fmt.Errorf("get FORWARD chain: %w", err)
  871. }
  872. err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
  873. if err != nil {
  874. return fmt.Errorf("Addhook: %w", err)
  875. }
  876. }
  877. for _, table := range n.getNATTables() {
  878. postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
  879. if err != nil {
  880. return fmt.Errorf("get INPUT chain: %w", err)
  881. }
  882. err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
  883. if err != nil {
  884. return fmt.Errorf("Addhook: %w", err)
  885. }
  886. }
  887. return nil
  888. }
  889. // delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
  890. func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
  891. rule := createHookRule(table, fromChain, toChainName)
  892. existingRule, err := findRule(conn, rule)
  893. if err != nil {
  894. return fmt.Errorf("Failed to find hook rule: %w", err)
  895. }
  896. if existingRule == nil {
  897. return nil
  898. }
  899. _ = conn.DelRule(existingRule)
  900. if err := conn.Flush(); err != nil {
  901. return fmt.Errorf("flush del hook rule: %w", err)
  902. }
  903. return nil
  904. }
  905. // DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
  906. func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
  907. conn := n.conn
  908. for _, table := range n.getTables() {
  909. inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
  910. if err != nil {
  911. return fmt.Errorf("get INPUT chain: %w", err)
  912. }
  913. err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
  914. if err != nil {
  915. return fmt.Errorf("delhook: %w", err)
  916. }
  917. forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
  918. if err != nil {
  919. return fmt.Errorf("get FORWARD chain: %w", err)
  920. }
  921. err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
  922. if err != nil {
  923. return fmt.Errorf("delhook: %w", err)
  924. }
  925. }
  926. for _, table := range n.getNATTables() {
  927. postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
  928. if err != nil {
  929. return fmt.Errorf("get INPUT chain: %w", err)
  930. }
  931. err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
  932. if err != nil {
  933. return fmt.Errorf("delhook: %w", err)
  934. }
  935. }
  936. return nil
  937. }
  // maskof returns the netmask of pfx as 4 big-endian bytes: the top
  // pfx.Bits() bits are set and the remaining bits are clear. Only the prefix
  // length is consulted; the result is always 4 bytes long.
  func maskof(pfx netip.Prefix) []byte {
  	var mask [4]byte
  	// hostBits has ones exactly where the netmask has zeros.
  	hostBits := uint32(0xffff_ffff) >> pfx.Bits()
  	binary.BigEndian.PutUint32(mask[:], ^hostBits)
  	return mask[:]
  }
  944. // createRangeRule creates a rule that matches packets with source IP from the give
  945. // range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
  946. // and makes the given decision. Only IPv4 is supported.
  947. func createRangeRule(
  948. table *nftables.Table, chain *nftables.Chain,
  949. tunname string, rng netip.Prefix, decision expr.VerdictKind,
  950. ) (*nftables.Rule, error) {
  951. if rng.Addr().Is6() {
  952. return nil, errors.New("IPv6 is not supported")
  953. }
  954. saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
  955. if err != nil {
  956. return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
  957. }
  958. netip := rng.Addr().AsSlice()
  959. mask := maskof(rng)
  960. rule := &nftables.Rule{
  961. Table: table,
  962. Chain: chain,
  963. Exprs: []expr.Any{
  964. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  965. &expr.Cmp{
  966. Op: expr.CmpOpNeq,
  967. Register: 1,
  968. Data: []byte(tunname),
  969. },
  970. saddrExpr,
  971. &expr.Bitwise{
  972. SourceRegister: 1,
  973. DestRegister: 1,
  974. Len: 4,
  975. Mask: mask,
  976. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  977. },
  978. &expr.Cmp{
  979. Op: expr.CmpOpEq,
  980. Register: 1,
  981. Data: netip,
  982. },
  983. &expr.Counter{},
  984. &expr.Verdict{
  985. Kind: decision,
  986. },
  987. },
  988. }
  989. return rule, nil
  990. }
  991. // addReturnChromeOSVMRangeRule adds a rule to return if the source IP
  992. // is in the ChromeOS VM range.
  993. func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  994. rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
  995. if err != nil {
  996. return fmt.Errorf("create rule: %w", err)
  997. }
  998. _ = c.AddRule(rule)
  999. if err = c.Flush(); err != nil {
  1000. return fmt.Errorf("add rule: %w", err)
  1001. }
  1002. return nil
  1003. }
  1004. // addDropCGNATRangeRule adds a rule to drop if the source IP is in the
  1005. // CGNAT range.
  1006. func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1007. rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
  1008. if err != nil {
  1009. return fmt.Errorf("create rule: %w", err)
  1010. }
  1011. _ = c.AddRule(rule)
  1012. if err = c.Flush(); err != nil {
  1013. return fmt.Errorf("add rule: %w", err)
  1014. }
  1015. return nil
  1016. }
  1017. // createSetSubnetRouteMarkRule creates a rule to set the subnet route
  1018. // mark if the packet is from the given interface.
  1019. func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
  1020. hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
  1021. hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
  1022. rule := &nftables.Rule{
  1023. Table: table,
  1024. Chain: chain,
  1025. Exprs: []expr.Any{
  1026. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  1027. &expr.Cmp{
  1028. Op: expr.CmpOpEq,
  1029. Register: 1,
  1030. Data: []byte(tunname),
  1031. },
  1032. &expr.Counter{},
  1033. &expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
  1034. &expr.Bitwise{
  1035. SourceRegister: 1,
  1036. DestRegister: 1,
  1037. Len: 4,
  1038. Mask: hexTsFwmarkMaskNeg,
  1039. Xor: hexTSSubnetRouteMark,
  1040. },
  1041. &expr.Meta{
  1042. Key: expr.MetaKeyMARK,
  1043. SourceRegister: true,
  1044. Register: 1,
  1045. },
  1046. },
  1047. }
  1048. return rule, nil
  1049. }
  1050. // addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
  1051. // if the packet is from the given interface.
  1052. func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1053. rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
  1054. if err != nil {
  1055. return fmt.Errorf("create rule: %w", err)
  1056. }
  1057. _ = c.AddRule(rule)
  1058. if err := c.Flush(); err != nil {
  1059. return fmt.Errorf("add rule: %w", err)
  1060. }
  1061. return nil
  1062. }
  1063. // createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
  1064. // outgoing packets from the CGNAT range.
  1065. func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
  1066. _, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
  1067. if err != nil {
  1068. return nil, fmt.Errorf("parse cidr: %v", err)
  1069. }
  1070. mask, err := hex.DecodeString(ipNet.Mask.String())
  1071. if err != nil {
  1072. return nil, fmt.Errorf("decode mask: %v", err)
  1073. }
  1074. netip := ipNet.IP.Mask(ipNet.Mask).To4()
  1075. saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
  1076. if err != nil {
  1077. return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
  1078. }
  1079. rule := &nftables.Rule{
  1080. Table: table,
  1081. Chain: chain,
  1082. Exprs: []expr.Any{
  1083. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  1084. &expr.Cmp{
  1085. Op: expr.CmpOpEq,
  1086. Register: 1,
  1087. Data: []byte(tunname),
  1088. },
  1089. saddrExpr,
  1090. &expr.Bitwise{
  1091. SourceRegister: 1,
  1092. DestRegister: 1,
  1093. Len: 4,
  1094. Mask: mask,
  1095. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  1096. },
  1097. &expr.Cmp{
  1098. Op: expr.CmpOpEq,
  1099. Register: 1,
  1100. Data: netip,
  1101. },
  1102. &expr.Counter{},
  1103. &expr.Verdict{
  1104. Kind: expr.VerdictDrop,
  1105. },
  1106. },
  1107. }
  1108. return rule, nil
  1109. }
  1110. // addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
  1111. // outgoing packets from the CGNAT range.
  1112. func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1113. rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
  1114. if err != nil {
  1115. return fmt.Errorf("create rule: %w", err)
  1116. }
  1117. _ = conn.AddRule(rule)
  1118. if err := conn.Flush(); err != nil {
  1119. return fmt.Errorf("add rule: %w", err)
  1120. }
  1121. return nil
  1122. }
  1123. // createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
  1124. // from the given interface.
  1125. func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
  1126. return &nftables.Rule{
  1127. Table: table,
  1128. Chain: chain,
  1129. Exprs: []expr.Any{
  1130. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  1131. &expr.Cmp{
  1132. Op: expr.CmpOpEq,
  1133. Register: 1,
  1134. Data: []byte(tunname),
  1135. },
  1136. &expr.Counter{},
  1137. &expr.Verdict{
  1138. Kind: expr.VerdictAccept,
  1139. },
  1140. },
  1141. }
  1142. }
  1143. // addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
  1144. // from the given interface.
  1145. func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1146. rule := createAcceptOutgoingPacketRule(table, chain, tunname)
  1147. _ = conn.AddRule(rule)
  1148. if err := conn.Flush(); err != nil {
  1149. return fmt.Errorf("flush add rule: %w", err)
  1150. }
  1151. return nil
  1152. }
  1153. // createAcceptOnPortRule creates a rule to accept incoming packets to
  1154. // a given destination UDP port.
  1155. func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule {
  1156. portBytes := make([]byte, 2)
  1157. binary.BigEndian.PutUint16(portBytes, port)
  1158. return &nftables.Rule{
  1159. Table: table,
  1160. Chain: chain,
  1161. Exprs: []expr.Any{
  1162. &expr.Meta{
  1163. Key: expr.MetaKeyL4PROTO,
  1164. Register: 1,
  1165. },
  1166. &expr.Cmp{
  1167. Op: expr.CmpOpEq,
  1168. Register: 1,
  1169. Data: []byte{unix.IPPROTO_UDP},
  1170. },
  1171. newLoadDportExpr(1),
  1172. &expr.Cmp{
  1173. Op: expr.CmpOpEq,
  1174. Register: 1,
  1175. Data: portBytes,
  1176. },
  1177. &expr.Counter{},
  1178. &expr.Verdict{
  1179. Kind: expr.VerdictAccept,
  1180. },
  1181. },
  1182. }
  1183. }
  1184. // addAcceptOnPortRule adds a rule to accept incoming packets to
  1185. // a given destination UDP port.
  1186. func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
  1187. rule := createAcceptOnPortRule(table, chain, port)
  1188. _ = conn.AddRule(rule)
  1189. if err := conn.Flush(); err != nil {
  1190. return fmt.Errorf("flush add rule: %w", err)
  1191. }
  1192. return nil
  1193. }
  1194. // addAcceptOnPortRule removes a rule to accept incoming packets to
  1195. // a given destination UDP port.
  1196. func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
  1197. rule := createAcceptOnPortRule(table, chain, port)
  1198. rule, err := findRule(conn, rule)
  1199. if err != nil {
  1200. return fmt.Errorf("find rule: %v", err)
  1201. }
  1202. _ = conn.DelRule(rule)
  1203. if err := conn.Flush(); err != nil {
  1204. return fmt.Errorf("flush del rule: %w", err)
  1205. }
  1206. return nil
  1207. }
  1208. // AddMagicsockPortRule adds a rule to nftables to allow incoming traffic on
  1209. // the specified UDP port, so magicsock can accept incoming connections.
  1210. // network must be either "udp4" or "udp6" - this determines whether the rule
  1211. // is added for IPv4 or IPv6.
  1212. func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
  1213. var filterTable *nftables.Table
  1214. switch network {
  1215. case "udp4":
  1216. filterTable = n.nft4.Filter
  1217. case "udp6":
  1218. filterTable = n.nft6.Filter
  1219. default:
  1220. return fmt.Errorf("unsupported network %s", network)
  1221. }
  1222. inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
  1223. if err != nil {
  1224. return fmt.Errorf("get input chain: %v", err)
  1225. }
  1226. err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port)
  1227. if err != nil {
  1228. return fmt.Errorf("add accept on port rule: %v", err)
  1229. }
  1230. return nil
  1231. }
  1232. // DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept
  1233. // incoming traffic on a particular UDP port.
  1234. // network must be either "udp4" or "udp6" - this determines whether the rule
  1235. // is removed for IPv4 or IPv6.
  1236. func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
  1237. var filterTable *nftables.Table
  1238. switch network {
  1239. case "udp4":
  1240. filterTable = n.nft4.Filter
  1241. case "udp6":
  1242. filterTable = n.nft6.Filter
  1243. default:
  1244. return fmt.Errorf("unsupported network %s", network)
  1245. }
  1246. inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
  1247. if err != nil {
  1248. return fmt.Errorf("get input chain: %v", err)
  1249. }
  1250. err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port)
  1251. if err != nil {
  1252. return fmt.Errorf("add accept on port rule: %v", err)
  1253. }
  1254. return nil
  1255. }
  1256. // createAcceptIncomingPacketRule creates a rule to accept incoming packets to
  1257. // the given interface.
  1258. func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
  1259. return &nftables.Rule{
  1260. Table: table,
  1261. Chain: chain,
  1262. Exprs: []expr.Any{
  1263. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  1264. &expr.Cmp{
  1265. Op: expr.CmpOpEq,
  1266. Register: 1,
  1267. Data: []byte(tunname),
  1268. },
  1269. &expr.Counter{},
  1270. &expr.Verdict{
  1271. Kind: expr.VerdictAccept,
  1272. },
  1273. },
  1274. }
  1275. }
  1276. func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1277. rule := createAcceptIncomingPacketRule(table, chain, tunname)
  1278. _ = conn.AddRule(rule)
  1279. if err := conn.Flush(); err != nil {
  1280. return fmt.Errorf("flush add rule: %w", err)
  1281. }
  1282. return nil
  1283. }
  1284. // AddBase adds some basic processing rules.
  1285. func (n *nftablesRunner) AddBase(tunname string) error {
  1286. if err := n.addBase4(tunname); err != nil {
  1287. return fmt.Errorf("add base v4: %w", err)
  1288. }
  1289. if n.HasIPV6() {
  1290. if err := n.addBase6(tunname); err != nil {
  1291. return fmt.Errorf("add base v6: %w", err)
  1292. }
  1293. }
  1294. return nil
  1295. }
  1296. // addBase4 adds some basic IPv4 processing rules.
  1297. func (n *nftablesRunner) addBase4(tunname string) error {
  1298. conn := n.conn
  1299. inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
  1300. if err != nil {
  1301. return fmt.Errorf("get input chain v4: %v", err)
  1302. }
  1303. if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1304. return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
  1305. }
  1306. if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1307. return fmt.Errorf("add drop cgnat range rule v4: %w", err)
  1308. }
  1309. if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1310. return fmt.Errorf("add accept incoming packet rule v4: %w", err)
  1311. }
  1312. forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
  1313. if err != nil {
  1314. return fmt.Errorf("get forward chain v4: %v", err)
  1315. }
  1316. if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1317. return fmt.Errorf("add set subnet route mark rule v4: %w", err)
  1318. }
  1319. if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
  1320. return fmt.Errorf("add match subnet route mark rule v4: %w", err)
  1321. }
  1322. if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1323. return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
  1324. }
  1325. if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1326. return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
  1327. }
  1328. if err = conn.Flush(); err != nil {
  1329. return fmt.Errorf("flush base v4: %w", err)
  1330. }
  1331. return nil
  1332. }
  1333. // addBase6 adds some basic IPv6 processing rules.
  1334. func (n *nftablesRunner) addBase6(tunname string) error {
  1335. conn := n.conn
  1336. inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
  1337. if err != nil {
  1338. return fmt.Errorf("get input chain v4: %v", err)
  1339. }
  1340. if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
  1341. return fmt.Errorf("add accept incoming packet rule v6: %w", err)
  1342. }
  1343. forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
  1344. if err != nil {
  1345. return fmt.Errorf("get forward chain v6: %w", err)
  1346. }
  1347. if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
  1348. return fmt.Errorf("add set subnet route mark rule v6: %w", err)
  1349. }
  1350. if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
  1351. return fmt.Errorf("add match subnet route mark rule v6: %w", err)
  1352. }
  1353. if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
  1354. return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
  1355. }
  1356. if err = conn.Flush(); err != nil {
  1357. return fmt.Errorf("flush base v6: %w", err)
  1358. }
  1359. return nil
  1360. }
  1361. // DelBase empties, but does not remove, custom Tailscale chains from
  1362. // netfilter via iptables.
  1363. func (n *nftablesRunner) DelBase() error {
  1364. conn := n.conn
  1365. for _, table := range n.getTables() {
  1366. inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
  1367. if err != nil {
  1368. return fmt.Errorf("get input chain: %v", err)
  1369. }
  1370. conn.FlushChain(inputChain)
  1371. forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
  1372. if err != nil {
  1373. return fmt.Errorf("get forward chain: %v", err)
  1374. }
  1375. conn.FlushChain(forwardChain)
  1376. }
  1377. for _, table := range n.getNATTables() {
  1378. postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1379. if err != nil {
  1380. return fmt.Errorf("get postrouting chain v4: %v", err)
  1381. }
  1382. conn.FlushChain(postrouteChain)
  1383. }
  1384. return conn.Flush()
  1385. }
  1386. // createMatchSubnetRouteMarkRule creates a rule that matches packets
  1387. // with the subnet route mark and takes the specified action.
  1388. func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
  1389. hexTSFwmarkMask := getTailscaleFwmarkMask()
  1390. hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
  1391. var endAction expr.Any
  1392. endAction = &expr.Verdict{Kind: expr.VerdictAccept}
  1393. if action == Masq {
  1394. endAction = &expr.Masq{}
  1395. }
  1396. exprs := []expr.Any{
  1397. &expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
  1398. &expr.Bitwise{
  1399. SourceRegister: 1,
  1400. DestRegister: 1,
  1401. Len: 4,
  1402. Mask: hexTSFwmarkMask,
  1403. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  1404. },
  1405. &expr.Cmp{
  1406. Op: expr.CmpOpEq,
  1407. Register: 1,
  1408. Data: hexTSSubnetRouteMark,
  1409. },
  1410. &expr.Counter{},
  1411. endAction,
  1412. }
  1413. rule := &nftables.Rule{
  1414. Table: table,
  1415. Chain: chain,
  1416. Exprs: exprs,
  1417. }
  1418. return rule, nil
  1419. }
  1420. // addMatchSubnetRouteMarkRule adds a rule that matches packets with
  1421. // the subnet route mark and takes the specified action.
  1422. func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
  1423. rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
  1424. if err != nil {
  1425. return fmt.Errorf("create match subnet route mark rule: %w", err)
  1426. }
  1427. _ = conn.AddRule(rule)
  1428. if err := conn.Flush(); err != nil {
  1429. return fmt.Errorf("flush add rule: %w", err)
  1430. }
  1431. return nil
  1432. }
  1433. // AddSNATRule adds a netfilter rule to SNAT traffic destined for
  1434. // local subnets.
  1435. func (n *nftablesRunner) AddSNATRule() error {
  1436. conn := n.conn
  1437. for _, table := range n.getNATTables() {
  1438. chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1439. if err != nil {
  1440. return fmt.Errorf("get postrouting chain v4: %w", err)
  1441. }
  1442. if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil {
  1443. return fmt.Errorf("add match subnet route mark rule v4: %w", err)
  1444. }
  1445. }
  1446. if err := conn.Flush(); err != nil {
  1447. return fmt.Errorf("flush add SNAT rule: %w", err)
  1448. }
  1449. return nil
  1450. }
  1451. // DelSNATRule removes the netfilter rule to SNAT traffic destined for
  1452. // local subnets. An error is returned if the rule does not exist.
  1453. func (n *nftablesRunner) DelSNATRule() error {
  1454. conn := n.conn
  1455. hexTSFwmarkMask := getTailscaleFwmarkMask()
  1456. hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
  1457. exprs := []expr.Any{
  1458. &expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
  1459. &expr.Bitwise{
  1460. SourceRegister: 1,
  1461. DestRegister: 1,
  1462. Len: 4,
  1463. Mask: hexTSFwmarkMask,
  1464. },
  1465. &expr.Cmp{
  1466. Op: expr.CmpOpEq,
  1467. Register: 1,
  1468. Data: hexTSSubnetRouteMark,
  1469. },
  1470. &expr.Counter{},
  1471. &expr.Masq{},
  1472. }
  1473. for _, table := range n.getNATTables() {
  1474. chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1475. if err != nil {
  1476. return fmt.Errorf("get postrouting chain v4: %w", err)
  1477. }
  1478. rule := &nftables.Rule{
  1479. Table: table.Nat,
  1480. Chain: chain,
  1481. Exprs: exprs,
  1482. }
  1483. SNATRule, err := findRule(conn, rule)
  1484. if err != nil {
  1485. return fmt.Errorf("find SNAT rule v4: %w", err)
  1486. }
  1487. if SNATRule != nil {
  1488. _ = conn.DelRule(SNATRule)
  1489. }
  1490. }
  1491. if err := conn.Flush(); err != nil {
  1492. return fmt.Errorf("flush del SNAT rule: %w", err)
  1493. }
  1494. return nil
  1495. }
  1496. // cleanupChain removes a jump rule from hookChainName to tsChainName, and then
  1497. // the entire chain tsChainName. Errors are logged, but attempts to remove both
  1498. // the jump rule and chain continue even if one errors.
  1499. func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
  1500. // remove the jump first, before removing the jump destination.
  1501. defaultChain, err := getChainFromTable(conn, table, hookChainName)
  1502. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
  1503. logf("cleanup: did not find default chain: %s", err)
  1504. }
  1505. if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
  1506. // delete hook in convention chain
  1507. _ = delHookRule(conn, table, defaultChain, tsChainName)
  1508. }
  1509. tsChain, err := getChainFromTable(conn, table, tsChainName)
  1510. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
  1511. logf("cleanup: did not find ts-chain: %s", err)
  1512. }
  1513. if tsChain != nil {
  1514. // flush and delete ts-chain
  1515. conn.FlushChain(tsChain)
  1516. conn.DelChain(tsChain)
  1517. err = conn.Flush()
  1518. logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
  1519. }
  1520. }
  1521. // NfTablesCleanUp removes all Tailscale added nftables rules.
  1522. // Any errors that occur are logged to the provided logf.
  1523. func NfTablesCleanUp(logf logger.Logf) {
  1524. conn, err := nftables.New()
  1525. if err != nil {
  1526. logf("cleanup: nftables connection: %s", err)
  1527. }
  1528. tables, err := conn.ListTables() // both v4 and v6
  1529. if err != nil {
  1530. logf("cleanup: list tables: %s", err)
  1531. }
  1532. for _, table := range tables {
  1533. // These table names were used briefly in 1.48.0.
  1534. if table.Name == "ts-filter" || table.Name == "ts-nat" {
  1535. conn.DelTable(table)
  1536. if err := conn.Flush(); err != nil {
  1537. logf("cleanup: flush delete table %s: %s", table.Name, err)
  1538. }
  1539. }
  1540. if table.Name == "filter" {
  1541. cleanupChain(logf, conn, table, "INPUT", chainNameInput)
  1542. cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
  1543. }
  1544. if table.Name == "nat" {
  1545. cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
  1546. }
  1547. }
  1548. }