fsm.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. package marionette
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand"
  8. "net"
  9. "os"
  10. "regexp"
  11. "strconv"
  12. "sync"
  13. "github.com/redjack/marionette/fte"
  14. "github.com/redjack/marionette/mar"
  15. "go.uber.org/zap"
  16. )
  // ErrNoTransitions is returned from FSM.Next() when a transition's actions
// could not be evaluated because no action's regex matched the buffered input.
var ErrNoTransitions = errors.New("no transitions available")

// ErrRetryTransition is returned from FSM.Next() to ask for the transition to
// be attempted again. Execute() responds to it by calling Next() again.
var ErrRetryTransition = errors.New("retry transition")

// ErrUUIDMismatch is returned when a cell arrives carrying a different UUID,
// e.g. because the peer is running a different MAR document.
var ErrUUIDMismatch = errors.New("uuid mismatch")
  26. // FSM represents an interface for the Marionette state machine.
  27. type FSM interface {
  28. io.Closer
  29. // Document & FSM identifiers.
  30. UUID() int
  31. SetInstanceID(int)
  32. InstanceID() int
  33. // Party & networking.
  34. Party() string
  35. Host() string
  36. Port() int
  37. // The current state in the FSM.
  38. State() string
  39. // Returns true if State() == 'dead'
  40. Dead() bool
  41. // Moves to the next available state.
  42. // Returns ErrNoTransition if there is no state to move to.
  43. Next(ctx context.Context) error
  44. // Moves through the entire state machine until it reaches 'dead' state.
  45. Execute(ctx context.Context) error
  46. // Restarts the FSM so it can be reused.
  47. Reset()
  48. // Returns an FTE cipher or DFA from the cache or creates a new one.
  49. Cipher(regex string, n int) (Cipher, error)
  50. DFA(regex string, msgLen int) (DFA, error)
  51. // Returns the network connection attached to the FSM.
  52. Conn() *BufferedConn
  53. // Listen opens a new listener to accept data and drains into the buffer.
  54. Listen() (int, error)
  55. // Returns the stream set attached to the FSM.
  56. StreamSet() *StreamSet
  57. // Sets and retrieves key/values from the FSM.
  58. SetVar(key string, value interface{})
  59. Var(key string) interface{}
  60. // Returns a copy of the FSM with a different format.
  61. Clone(doc *mar.Document) FSM
  62. Logger() *zap.Logger
  63. }
  64. // Ensure implementation implements interface.
  65. var _ FSM = &fsm{}
  66. // fsm is the default implementation of the FSM.
  67. type fsm struct {
  68. mu sync.Mutex
  69. doc *mar.Document // executing document
  70. host string // bind hostname
  71. party string // "client", "server"
  72. fteCache *fte.Cache
  73. conn *BufferedConn // connection to remote peer
  74. streamSet *StreamSet // multiplexing stream set
  75. listeners map[int]net.Listener // spawn() listeners
  76. closeFuncs []func() error // closers used by spawn()
  77. state string // current state
  78. stepN int // number of steps completed
  79. rand *rand.Rand // PRNG, seed shared by peer
  80. // Close management
  81. closed bool
  82. ctx context.Context
  83. cancel func()
  84. // Lookup of transitions by src state.
  85. transitions map[string][]*mar.Transition
  86. // Variable storage used by tg module.
  87. vars map[string]interface{}
  88. // Set by the first sender and used to seed PRNG.
  89. instanceID int
  90. }
  91. // NewFSM returns a new FSM. If party is the first sender then the instance id is set.
  92. func NewFSM(doc *mar.Document, host, party string, conn net.Conn, streamSet *StreamSet) FSM {
  93. fsm := &fsm{
  94. state: "start",
  95. vars: make(map[string]interface{}),
  96. doc: doc,
  97. host: host,
  98. party: party,
  99. fteCache: fte.NewCache(),
  100. conn: NewBufferedConn(conn, MaxCellLength),
  101. streamSet: streamSet,
  102. listeners: make(map[int]net.Listener),
  103. }
  104. fsm.ctx, fsm.cancel = context.WithCancel(context.TODO())
  105. fsm.buildTransitions()
  106. fsm.initFirstSender()
  107. return fsm
  108. }
  109. // buildTransitions caches a mapping of source to destination transition for the document.
  110. func (fsm *fsm) buildTransitions() {
  111. fsm.transitions = make(map[string][]*mar.Transition)
  112. for _, t := range fsm.doc.Transitions {
  113. fsm.transitions[t.Source] = append(fsm.transitions[t.Source], t)
  114. }
  115. }
  116. // initFirstSender generates an instance ID & seeds the PRNG if this party initiates the connection.
  117. func (fsm *fsm) initFirstSender() {
  118. if fsm.party != fsm.doc.FirstSender() {
  119. return
  120. }
  121. fsm.instanceID = int(rand.Int31())
  122. fsm.rand = rand.New(rand.NewSource(int64(fsm.instanceID)))
  123. }
  124. // Close closes the underlying connection & context.
  125. func (fsm *fsm) Close() error {
  126. fsm.mu.Lock()
  127. defer fsm.mu.Unlock()
  128. fsm.closed = true
  129. fsm.cancel()
  130. return fsm.Conn().Close()
  131. }
  132. // Closed returns true if FSM has been closed.
  133. func (fsm *fsm) Closed() bool {
  134. fsm.mu.Lock()
  135. defer fsm.mu.Unlock()
  136. return fsm.closed
  137. }
  138. // Reset resets the state and variable set.
  139. func (fsm *fsm) Reset() {
  140. fsm.state = "start"
  141. fsm.vars = make(map[string]interface{})
  142. for _, fn := range fsm.closeFuncs {
  143. if err := fn(); err != nil {
  144. fsm.Logger().Error("close error", zap.Error(err))
  145. }
  146. }
  147. fsm.closeFuncs = nil
  148. }
  149. // UUID returns the computed MAR document UUID.
  150. func (fsm *fsm) UUID() int { return fsm.doc.UUID }
  151. // InstanceID returns the ID for this specific FSM.
  152. func (fsm *fsm) InstanceID() int { return fsm.instanceID }
  153. // SetInstanceID sets the ID for the FSM.
  154. func (fsm *fsm) SetInstanceID(id int) { fsm.instanceID = id }
  155. // State returns the current state of the FSM.
  156. func (fsm *fsm) State() string { return fsm.state }
  157. // Conn returns the connection the FSM was initialized with.
  158. func (fsm *fsm) Conn() *BufferedConn { return fsm.conn }
  159. // StreamSet returns the stream set the FSM was initialized with.
  160. func (fsm *fsm) StreamSet() *StreamSet { return fsm.streamSet }
  161. // Host returns the hostname the FSM was initialized with.
  162. func (fsm *fsm) Host() string { return fsm.host }
  163. // Party returns "client" or "server" depending on who is initializing the FSM.
  164. func (fsm *fsm) Party() string { return fsm.party }
  165. // Port returns the port from the underlying document.
  166. // If port is a named port then it is looked up in the local variables.
  167. func (fsm *fsm) Port() int {
  168. // Use specified port, if numeric.
  169. if port, err := strconv.Atoi(fsm.doc.Port); err == nil {
  170. return port
  171. }
  172. // Otherwise lookup port set as a variable.
  173. if v := fsm.Var(fsm.doc.Port); v != nil {
  174. port, _ := v.(int)
  175. return port
  176. }
  177. return 0
  178. }
  179. // Dead returns true when the FSM is complete.
  180. func (fsm *fsm) Dead() bool { return fsm.state == "dead" }
  181. // Execute runs the the FSM to completion.
  182. func (fsm *fsm) Execute(ctx context.Context) error {
  183. // If no connection is passed in, create one.
  184. // This occurs when an FSM is spawned.
  185. if err := fsm.ensureConn(ctx); err != nil {
  186. return err
  187. }
  188. // Continually move to the next state until we reach the "dead" state.
  189. for !fsm.Dead() {
  190. // Transitions can request to retry if the instance ID is updated.
  191. // In this case, the PRNG is seeded and stepN steps are reprocessed w/ new PRNG.
  192. if err := fsm.Next(ctx); err == ErrRetryTransition {
  193. fsm.Logger().Debug("retry transition", zap.String("state", fsm.State()))
  194. continue
  195. } else if err != nil {
  196. return err
  197. }
  198. }
  199. return nil
  200. }
  201. // Next transitions to the next state in the executing MAR document..
  202. func (fsm *fsm) Next(ctx context.Context) (err error) {
  203. // Notify caller stream is closed if FSM has been closed.
  204. if fsm.Closed() {
  205. return ErrStreamClosed
  206. }
  207. // Generate a new PRNG once we have an instance ID.
  208. if err := fsm.init(); err != nil {
  209. return err
  210. }
  211. // If we have a successful transition, update our state info.
  212. // Exit if no transitions were successful.
  213. nextState, err := fsm.next(true)
  214. if err != nil {
  215. return err
  216. }
  217. // Track number of steps so they can be replayed once the instance ID is received.
  218. // This only occurs if FSM's party is not the first sender.
  219. fsm.stepN += 1
  220. fsm.state = nextState
  221. return nil
  222. }
  223. func (fsm *fsm) next(eval bool) (nextState string, err error) {
  224. // Find all possible transitions from the current state.
  225. transitions := mar.FilterTransitionsBySource(fsm.doc.Transitions, fsm.state)
  226. errorTransitions := mar.FilterErrorTransitions(transitions)
  227. // Then filter by PRNG (if available) or return all (if unavailable).
  228. transitions = mar.FilterNonErrorTransitions(transitions)
  229. transitions = mar.ChooseTransitions(transitions, fsm.rand)
  230. assert(len(transitions) > 0)
  231. // Add error transitions back in after selection.
  232. transitions = append(transitions, errorTransitions...)
  233. // Attempt each possible transition.
  234. for _, transition := range transitions {
  235. // If there's no action block then move to the next state.
  236. if transition.ActionBlock == "NULL" {
  237. return transition.Destination, nil
  238. }
  239. // Find all actions for this destination and current party.
  240. blk := fsm.doc.ActionBlock(transition.ActionBlock)
  241. if blk == nil {
  242. return "", fmt.Errorf("fsm.Next(): action block not found: %q", transition.ActionBlock)
  243. }
  244. actions := mar.FilterActionsByParty(blk.Actions, fsm.party)
  245. // Attempt to execute each action.
  246. if eval {
  247. if err := fsm.evalActions(actions); err != nil {
  248. return "", err
  249. }
  250. }
  251. return transition.Destination, nil
  252. }
  253. return "", nil
  254. }
  255. // init initializes the PRNG if we now have a instance id.
  256. func (fsm *fsm) init() (err error) {
  257. // Skip if already initialized or we don't have an instance ID yet.
  258. if fsm.rand != nil || fsm.instanceID == 0 {
  259. return nil
  260. }
  261. // Create new PRNG.
  262. fsm.rand = rand.New(rand.NewSource(int64(fsm.instanceID)))
  263. // Restart FSM from the beginning and iterate until the current step.
  264. fsm.state = "start"
  265. for i := 0; i < fsm.stepN; i++ {
  266. fsm.state, err = fsm.next(false)
  267. if err != nil {
  268. return err
  269. }
  270. assert(fsm.state != "")
  271. }
  272. return nil
  273. }
  274. // evalActions attempts to evaluate every action until one succeeds.
  275. func (fsm *fsm) evalActions(actions []*mar.Action) error {
  276. if len(actions) == 0 {
  277. return nil
  278. }
  279. for _, action := range actions {
  280. // If there is no matching regex then simply evaluate action.
  281. if action.Regex != "" {
  282. // Compile regex.
  283. re, err := regexp.Compile(action.Regex)
  284. if err != nil {
  285. return err
  286. }
  287. // Only evaluate action if buffer matches.
  288. buf, err := fsm.conn.Peek(-1, false)
  289. if err != nil {
  290. return err
  291. } else if !re.Match(buf) {
  292. continue
  293. }
  294. }
  295. fn := FindPlugin(action.Module, action.Method)
  296. if fn == nil {
  297. return fmt.Errorf("plugin not found: %s", action.Name())
  298. } else if err := fn(fsm.ctx, fsm, action.ArgValues()...); err != nil {
  299. return err
  300. }
  301. return nil
  302. }
  303. return ErrNoTransitions
  304. }
  305. // Var returns the variable value for a given key.
  306. func (fsm *fsm) Var(key string) interface{} {
  307. switch key {
  308. case "model_instance_id":
  309. return fsm.InstanceID
  310. case "model_uuid":
  311. return fsm.doc.UUID
  312. case "party":
  313. return fsm.party
  314. default:
  315. return fsm.vars[key]
  316. }
  317. }
  318. // SetVar sets the variable value for a given key.
  319. func (fsm *fsm) SetVar(key string, value interface{}) {
  320. fsm.vars[key] = value
  321. }
  322. // Cipher returns a cipher with the given settings.
  323. // If no cipher exists then a new one is created and returned.
  324. func (fsm *fsm) Cipher(regex string, n int) (Cipher, error) {
  325. return fsm.fteCache.Cipher(regex, n)
  326. }
  327. // DFA returns a DFA with the given settings.
  328. // If no DFA exists then a new one is created and returned.
  329. func (fsm *fsm) DFA(regex string, n int) (DFA, error) {
  330. return fsm.fteCache.DFA(regex, n)
  331. }
  332. // Listen opens a listener used by channel.bind(). Listener closed by Close().
  333. //
  334. // Port is chosen randomly unless MARIONETTE_CHANNEL_BIND_PORT environment variable is set.
  335. func (fsm *fsm) Listen() (port int, err error) {
  336. addr := fsm.host
  337. if s := os.Getenv("MARIONETTE_CHANNEL_BIND_PORT"); s != "" {
  338. addr = net.JoinHostPort(addr, s)
  339. }
  340. ln, err := net.Listen("tcp", addr)
  341. if err != nil {
  342. return 0, err
  343. }
  344. port = ln.Addr().(*net.TCPAddr).Port
  345. fsm.listeners[port] = ln
  346. fsm.closeFuncs = append(fsm.closeFuncs, ln.Close)
  347. return port, nil
  348. }
  349. // ensureConn ensures that the conn variable is set. Root FSMs are populated with
  350. // a connection during instantiation, however, spawned FSMs require new connections.
  351. //
  352. // For client parties, a new connection is dialed to the server.
  353. // For server parties, a listener is opened and it waits for the next accepted connection.
  354. func (fsm *fsm) ensureConn(ctx context.Context) error {
  355. if fsm.conn != nil {
  356. return nil
  357. }
  358. if fsm.party == PartyClient {
  359. return fsm.ensureClientConn(ctx)
  360. }
  361. return fsm.ensureServerConn(ctx)
  362. }
  363. // ensureClientConn dials a connection to the server. Connection closed on Close().
  364. func (fsm *fsm) ensureClientConn(ctx context.Context) error {
  365. conn, err := net.Dial(fsm.doc.Transport, net.JoinHostPort(fsm.host, strconv.Itoa(fsm.Port())))
  366. if err != nil {
  367. return err
  368. }
  369. fsm.conn = NewBufferedConn(conn, MaxCellLength)
  370. fsm.closeFuncs = append(fsm.closeFuncs, conn.Close)
  371. return nil
  372. }
  373. // ensureServerConn opens a listener and waits for the next connection.
  374. // Will reuse listener if previously spawned. Listener closed on Close().
  375. func (fsm *fsm) ensureServerConn(ctx context.Context) (err error) {
  376. ln := fsm.listeners[fsm.Port()]
  377. if ln == nil {
  378. if ln, err = net.Listen("tcp", net.JoinHostPort(fsm.host, strconv.Itoa(fsm.Port()))); err != nil {
  379. return err
  380. }
  381. fsm.listeners[fsm.Port()] = ln
  382. }
  383. conn, err := ln.Accept()
  384. if err != nil {
  385. return err
  386. }
  387. fsm.conn = NewBufferedConn(conn, MaxCellLength)
  388. fsm.closeFuncs = append(fsm.closeFuncs, conn.Close)
  389. return nil
  390. }
  391. // Clone returns a copy of f. Used when spawning new FSMs.
  392. func (f *fsm) Clone(doc *mar.Document) FSM {
  393. other := &fsm{
  394. state: "start",
  395. vars: make(map[string]interface{}),
  396. doc: doc,
  397. host: f.host,
  398. party: f.party,
  399. fteCache: f.fteCache,
  400. streamSet: f.streamSet,
  401. listeners: f.listeners,
  402. }
  403. other.buildTransitions()
  404. other.initFirstSender()
  405. other.vars = make(map[string]interface{})
  406. for k, v := range f.vars {
  407. other.vars[k] = v
  408. }
  409. return other
  410. }
  411. // Logger returns the logger for this FSM.
  412. func (fsm *fsm) Logger() *zap.Logger {
  413. if fsm.Closed() {
  414. return zap.NewNop()
  415. }
  416. return Logger.With(zap.String("party", fsm.party))
  417. }