The Problem with Switch-Based State Machines
When building state machines in Go, many developers reach for switch statements. While this works for simple cases, it quickly becomes unwieldy as your state machine grows. Each state transition requires scanning through multiple case blocks, and adding new states means touching existing code in multiple places.
Let me show you a cleaner approach: represent each state as a function. Because Go treats functions as first-class values, a state function can handle an event and return the function for the next state, so each state’s transitions live in one place and the machine stays easy to extend.
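The core idea fits in a few lines. The sketch below models a hypothetical coin-operated turnstile (not part of the original article). `State`, `locked`, `unlocked` and `run` are illustrative names. The key point is that `State` is a named function type that refers to itself, so each state can return the next one.

```go
package main

import "fmt"

// State handles one event and returns the next state plus a message.
// A named function type may mention itself in its own signature.
type State func(event string) (State, string)

// locked waits for a coin; pushing does nothing.
func locked(event string) (State, string) {
	switch event {
	case "coin":
		return unlocked, "unlocked"
	default:
		return locked, "still locked"
	}
}

// unlocked lets one person through, then locks again.
func unlocked(event string) (State, string) {
	switch event {
	case "push":
		return locked, "passed through, locked again"
	default:
		return unlocked, "already unlocked"
	}
}

// run feeds events to the machine, starting at start, and collects messages.
func run(start State, events ...string) []string {
	var out []string
	s := start
	for _, e := range events {
		var msg string
		s, msg = s(e)
		out = append(out, msg)
	}
	return out
}

func main() {
	for _, msg := range run(locked, "push", "coin", "push") {
		fmt.Println(msg)
	}
}
```

There is no central switch over states: the current state is simply whichever function was returned last.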
Traditional Switch-Based Approach
Here’s a typical implementation of a TCP connection state machine using switches:
```go
type ConnectionState int

const (
	Closed ConnectionState = iota
	Listen
	SynSent
	SynReceived
	Established
	FinWait1
	FinWait2
	CloseWait
	Closing
	LastAck
	TimeWait
)

type Connection struct {
	state ConnectionState
}

func (c *Connection) HandleEvent(event string) error {
	switch c.state {
	case Closed:
		switch event {
		case "PASSIVE_OPEN":
			c.state = Listen
		case "ACTIVE_OPEN":
			c.state = SynSent
		default:
			return fmt.Errorf("invalid event %s for state Closed", event)
		}
	case Listen:
		switch event {
		case "SYN":
			c.state = SynReceived
		case "CLOSE":
			c.state = Closed
		default:
			return fmt.Errorf("invalid event %s for state Listen", event)
		}
	// ... many more case blocks
	}
	return nil
}
```
This approach has several issues:
- Nested switches are hard to read
- Adding a new state requires modifying the switch
- State-specific logic is scattered across case blocks
- No compile-time guarantee that all states handle all events
Function-State Machine Pattern
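Here’s the same state machine using functions as states. Each state is a method on Connection that handles one event and returns the next state, and ProcessEvent simply stores whatever the current state returns: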
```go
package main

import (
	"fmt"
	"log"
)

// StateFunc represents a state as a function.
// It takes an event and returns the next state (or itself).
type StateFunc func(event string) (StateFunc, error)

// Connection represents a TCP connection.
type Connection struct {
	currentState StateFunc
	name         string
}

func NewConnection(name string) *Connection {
	c := &Connection{name: name}
	c.currentState = c.closedState
	return c
}

// ProcessEvent handles an event and transitions to the next state.
// On error the connection stays in its current state.
func (c *Connection) ProcessEvent(event string) error {
	nextState, err := c.currentState(event)
	if err != nil {
		return err
	}
	c.currentState = nextState
	return nil
}

// State functions: each state is a method that returns the next state.
func (c *Connection) closedState(event string) (StateFunc, error) {
	log.Printf("[%s] CLOSED: received event %s", c.name, event)
	switch event {
	case "PASSIVE_OPEN":
		log.Printf("[%s] Transitioning: CLOSED -> LISTEN", c.name)
		return c.listenState, nil
	case "ACTIVE_OPEN":
		log.Printf("[%s] Transitioning: CLOSED -> SYN_SENT", c.name)
		return c.synSentState, nil
	default:
		return c.closedState, fmt.Errorf("invalid event %s for CLOSED state", event)
	}
}

func (c *Connection) listenState(event string) (StateFunc, error) {
	log.Printf("[%s] LISTEN: received event %s", c.name, event)
	switch event {
	case "SYN":
		log.Printf("[%s] Transitioning: LISTEN -> SYN_RECEIVED", c.name)
		return c.synReceivedState, nil
	case "CLOSE":
		log.Printf("[%s] Transitioning: LISTEN -> CLOSED", c.name)
		return c.closedState, nil
	default:
		return c.listenState, fmt.Errorf("invalid event %s for LISTEN state", event)
	}
}

func (c *Connection) synSentState(event string) (StateFunc, error) {
	log.Printf("[%s] SYN_SENT: received event %s", c.name, event)
	switch event {
	case "SYN_ACK":
		log.Printf("[%s] Transitioning: SYN_SENT -> ESTABLISHED", c.name)
		return c.establishedState, nil
	case "CLOSE":
		log.Printf("[%s] Transitioning: SYN_SENT -> CLOSED", c.name)
		return c.closedState, nil
	default:
		return c.synSentState, fmt.Errorf("invalid event %s for SYN_SENT state", event)
	}
}

func (c *Connection) synReceivedState(event string) (StateFunc, error) {
	log.Printf("[%s] SYN_RECEIVED: received event %s", c.name, event)
	switch event {
	case "ACK":
		log.Printf("[%s] Transitioning: SYN_RECEIVED -> ESTABLISHED", c.name)
		return c.establishedState, nil
	case "CLOSE":
		log.Printf("[%s] Transitioning: SYN_RECEIVED -> FIN_WAIT_1", c.name)
		return c.finWait1State, nil
	default:
		return c.synReceivedState, fmt.Errorf("invalid event %s for SYN_RECEIVED state", event)
	}
}

func (c *Connection) establishedState(event string) (StateFunc, error) {
	log.Printf("[%s] ESTABLISHED: received event %s", c.name, event)
	switch event {
	case "CLOSE":
		log.Printf("[%s] Transitioning: ESTABLISHED -> FIN_WAIT_1", c.name)
		return c.finWait1State, nil
	case "FIN":
		log.Printf("[%s] Transitioning: ESTABLISHED -> CLOSE_WAIT", c.name)
		return c.closeWaitState, nil
	default:
		return c.establishedState, fmt.Errorf("invalid event %s for ESTABLISHED state", event)
	}
}

func (c *Connection) finWait1State(event string) (StateFunc, error) {
	log.Printf("[%s] FIN_WAIT_1: received event %s", c.name, event)
	switch event {
	case "ACK":
		log.Printf("[%s] Transitioning: FIN_WAIT_1 -> FIN_WAIT_2", c.name)
		return c.finWait2State, nil
	case "FIN":
		log.Printf("[%s] Transitioning: FIN_WAIT_1 -> CLOSING", c.name)
		return c.closingState, nil
	default:
		return c.finWait1State, fmt.Errorf("invalid event %s for FIN_WAIT_1 state", event)
	}
}

func (c *Connection) finWait2State(event string) (StateFunc, error) {
	log.Printf("[%s] FIN_WAIT_2: received event %s", c.name, event)
	switch event {
	case "FIN":
		log.Printf("[%s] Transitioning: FIN_WAIT_2 -> TIME_WAIT", c.name)
		return c.timeWaitState, nil
	default:
		return c.finWait2State, fmt.Errorf("invalid event %s for FIN_WAIT_2 state", event)
	}
}

func (c *Connection) closeWaitState(event string) (StateFunc, error) {
	log.Printf("[%s] CLOSE_WAIT: received event %s", c.name, event)
	switch event {
	case "CLOSE":
		log.Printf("[%s] Transitioning: CLOSE_WAIT -> LAST_ACK", c.name)
		return c.lastAckState, nil
	default:
		return c.closeWaitState, fmt.Errorf("invalid event %s for CLOSE_WAIT state", event)
	}
}

func (c *Connection) closingState(event string) (StateFunc, error) {
	log.Printf("[%s] CLOSING: received event %s", c.name, event)
	switch event {
	case "ACK":
		log.Printf("[%s] Transitioning: CLOSING -> TIME_WAIT", c.name)
		return c.timeWaitState, nil
	default:
		return c.closingState, fmt.Errorf("invalid event %s for CLOSING state", event)
	}
}

func (c *Connection) lastAckState(event string) (StateFunc, error) {
	log.Printf("[%s] LAST_ACK: received event %s", c.name, event)
	switch event {
	case "ACK":
		log.Printf("[%s] Transitioning: LAST_ACK -> CLOSED", c.name)
		return c.closedState, nil
	default:
		return c.lastAckState, fmt.Errorf("invalid event %s for LAST_ACK state", event)
	}
}

func (c *Connection) timeWaitState(event string) (StateFunc, error) {
	log.Printf("[%s] TIME_WAIT: received event %s", c.name, event)
	switch event {
	case "TIMEOUT":
		log.Printf("[%s] Transitioning: TIME_WAIT -> CLOSED", c.name)
		return c.closedState, nil
	default:
		return c.timeWaitState, fmt.Errorf("invalid event %s for TIME_WAIT state", event)
	}
}

func main() {
	fmt.Print("=== TCP Connection State Machine ===\n\n")

	// Simulate a normal (active) connection lifecycle
	fmt.Println("--- Scenario 1: Normal Connection ---")
	conn1 := NewConnection("Conn-1")
	events := []string{"ACTIVE_OPEN", "SYN_ACK", "CLOSE", "ACK", "FIN", "TIMEOUT"}
	for _, event := range events {
		if err := conn1.ProcessEvent(event); err != nil {
			log.Printf("[Conn-1] Error: %v", err)
		}
		fmt.Println()
	}

	// Simulate a passive connection
	fmt.Println("\n--- Scenario 2: Passive Connection ---")
	conn2 := NewConnection("Conn-2")
	events2 := []string{"PASSIVE_OPEN", "SYN", "ACK", "FIN", "CLOSE", "ACK"}
	for _, event := range events2 {
		if err := conn2.ProcessEvent(event); err != nil {
			log.Printf("[Conn-2] Error: %v", err)
		}
		fmt.Println()
	}

	// Test an invalid transition
	fmt.Println("\n--- Scenario 3: Invalid Transition ---")
	conn3 := NewConnection("Conn-3")
	if err := conn3.ProcessEvent("INVALID_EVENT"); err != nil {
		log.Printf("[Conn-3] Expected error: %v", err)
	}
}
```
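One practical wrinkle when testing these states: Go only allows comparing function values with nil, so a test cannot check `next == c.listenState`. A sketch of one workaround, assuming it is acceptable to rely on runtime symbol names, recovers the method name through `reflect` and `runtime.FuncForPC`. The trimmed `Connection` with two states and the `stateName` helper are hypothetical reductions of the code above, kept short so the block is self-contained.

```go
package main

import (
	"fmt"
	"reflect"
	"runtime"
	"strings"
)

type StateFunc func(event string) (StateFunc, error)

// Connection is a cut-down version of the article's type (no logging).
type Connection struct{}

func (c *Connection) closedState(event string) (StateFunc, error) {
	if event == "PASSIVE_OPEN" {
		return c.listenState, nil
	}
	return c.closedState, fmt.Errorf("invalid event %s for CLOSED state", event)
}

func (c *Connection) listenState(event string) (StateFunc, error) {
	if event == "CLOSE" {
		return c.closedState, nil
	}
	return c.listenState, fmt.Errorf("invalid event %s for LISTEN state", event)
}

// stateName returns a short name such as "listenState" for a state function.
// Method values are compiled to wrappers named like
// "main.(*Connection).listenState-fm", so the suffix and package path are
// stripped. This depends on runtime symbol names: a testing/debugging aid only.
func stateName(s StateFunc) string {
	full := runtime.FuncForPC(reflect.ValueOf(s).Pointer()).Name()
	full = strings.TrimSuffix(full, "-fm")
	return full[strings.LastIndex(full, ".")+1:]
}

func main() {
	c := &Connection{}
	next, err := c.closedState("PASSIVE_OPEN")
	fmt.Println(stateName(next), err)
}
```

If depending on symbol names feels too fragile, the alternative is to test behavior: feed the returned state a follow-up event and check whether it is accepted or rejected.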
Advanced Pattern: State Machine with Context
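In more complex scenarios, states need shared data such as counters or timestamps. You can pass a context value to every state function; the traffic light below keeps its per-color tick counts there. (Context here is a plain struct, unrelated to the standard library’s context.Context.)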
```go
package main

import (
	"fmt"
	"time"
)

// Context carries data through state transitions.
type Context struct {
	Data       map[string]interface{}
	StartTime  time.Time
	EventCount int
}

// StateFuncWithContext takes context and event, returns next state.
type StateFuncWithContext func(*Context, string) (StateFuncWithContext, error)

// TrafficLight implements a traffic light state machine.
type TrafficLight struct {
	currentState StateFuncWithContext
	context      *Context
	name         string
}

func NewTrafficLight(name string) *TrafficLight {
	tl := &TrafficLight{
		name: name,
		context: &Context{
			Data:      make(map[string]interface{}),
			StartTime: time.Now(),
		},
	}
	tl.currentState = tl.redState
	return tl
}

// handle counts the event, runs the current state, and moves to the state it
// returns. On error the light stays in its current state.
func (tl *TrafficLight) handle(event string) error {
	tl.context.EventCount++
	nextState, err := tl.currentState(tl.context, event)
	if err != nil {
		return err
	}
	tl.currentState = nextState
	return nil
}

// Tick advances the light by one time step.
func (tl *TrafficLight) Tick() error { return tl.handle("TICK") }

// Emergency enters or leaves flashing mode.
func (tl *TrafficLight) Emergency() error { return tl.handle("EMERGENCY") }

func (tl *TrafficLight) redState(ctx *Context, event string) (StateFuncWithContext, error) {
	switch event {
	case "TICK":
		// The comma-ok assertion yields 0 while the key is still unset.
		ticks, _ := ctx.Data["redDuration"].(int)
		ticks++
		if ticks > 3 {
			fmt.Printf("[%s] RED -> GREEN (waited %d ticks)\n", tl.name, ticks-1)
			ctx.Data["redDuration"] = 0
			return tl.greenState, nil
		}
		ctx.Data["redDuration"] = ticks
		fmt.Printf("[%s] RED (tick %d)\n", tl.name, ticks)
		return tl.redState, nil
	case "EMERGENCY":
		ctx.Data["redDuration"] = 0 // restart this phase once the emergency clears
		fmt.Printf("[%s] RED -> FLASHING (emergency)\n", tl.name)
		return tl.flashingState, nil
	default:
		return tl.redState, fmt.Errorf("invalid event %s for RED state", event)
	}
}

func (tl *TrafficLight) greenState(ctx *Context, event string) (StateFuncWithContext, error) {
	switch event {
	case "TICK":
		ticks, _ := ctx.Data["greenDuration"].(int)
		ticks++
		if ticks > 5 {
			fmt.Printf("[%s] GREEN -> YELLOW (waited %d ticks)\n", tl.name, ticks-1)
			ctx.Data["greenDuration"] = 0
			return tl.yellowState, nil
		}
		ctx.Data["greenDuration"] = ticks
		fmt.Printf("[%s] GREEN (tick %d)\n", tl.name, ticks)
		return tl.greenState, nil
	case "EMERGENCY":
		ctx.Data["greenDuration"] = 0 // otherwise the next green phase starts part-way through
		fmt.Printf("[%s] GREEN -> FLASHING (emergency)\n", tl.name)
		return tl.flashingState, nil
	default:
		return tl.greenState, fmt.Errorf("invalid event %s for GREEN state", event)
	}
}

func (tl *TrafficLight) yellowState(ctx *Context, event string) (StateFuncWithContext, error) {
	switch event {
	case "TICK":
		ticks, _ := ctx.Data["yellowDuration"].(int)
		ticks++
		if ticks > 2 {
			fmt.Printf("[%s] YELLOW -> RED (waited %d ticks)\n", tl.name, ticks-1)
			ctx.Data["yellowDuration"] = 0
			return tl.redState, nil
		}
		ctx.Data["yellowDuration"] = ticks
		fmt.Printf("[%s] YELLOW (tick %d)\n", tl.name, ticks)
		return tl.yellowState, nil
	case "EMERGENCY":
		ctx.Data["yellowDuration"] = 0 // otherwise the next yellow phase starts part-way through
		fmt.Printf("[%s] YELLOW -> FLASHING (emergency)\n", tl.name)
		return tl.flashingState, nil
	default:
		return tl.yellowState, fmt.Errorf("invalid event %s for YELLOW state", event)
	}
}

func (tl *TrafficLight) flashingState(ctx *Context, event string) (StateFuncWithContext, error) {
	switch event {
	case "TICK":
		fmt.Printf("[%s] FLASHING (emergency mode)\n", tl.name)
		return tl.flashingState, nil
	case "EMERGENCY":
		fmt.Printf("[%s] FLASHING -> RED (emergency cleared)\n", tl.name)
		return tl.redState, nil
	default:
		return tl.flashingState, fmt.Errorf("invalid event %s for FLASHING state", event)
	}
}

func main() {
	light := NewTrafficLight("Intersection-A")

	// Normal operation
	fmt.Println("=== Normal Traffic Light Cycle ===")
	for i := 0; i < 12; i++ {
		light.Tick()
	}

	// Emergency
	fmt.Println("\n=== Emergency Scenario ===")
	light.Emergency()
	light.Tick()
	light.Tick()
	light.Emergency() // Clear emergency
	light.Tick()
}
```
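Storing counters in `map[string]interface{}` means every read needs a type assertion, and a misspelled key such as `"redDuraton"` compiles and simply reads as missing. A sketch of the same timing logic with typed fields follows; `LightContext`, `LightState`, `red`, `green`, `yellow` and `run` are hypothetical names, and emergency handling is left out to keep it short. Each color stays for the same number of ticks as `redState`, `greenState` and `yellowState` above.

```go
package main

import "fmt"

// LightContext is a hypothetical typed replacement for Context.Data.
type LightContext struct {
	RedTicks    int
	GreenTicks  int
	YellowTicks int
}

// LightState mirrors StateFuncWithContext, but with the typed context.
type LightState func(ctx *LightContext, event string) (LightState, error)

// red shows for 3 ticks and switches to green on the 4th.
func red(ctx *LightContext, event string) (LightState, error) {
	if event != "TICK" {
		return red, fmt.Errorf("invalid event %s for RED state", event)
	}
	ctx.RedTicks++ // a typo here is a compile error, not a silent new map key
	if ctx.RedTicks > 3 {
		ctx.RedTicks = 0
		return green, nil
	}
	return red, nil
}

// green shows for 5 ticks and switches to yellow on the 6th.
func green(ctx *LightContext, event string) (LightState, error) {
	if event != "TICK" {
		return green, fmt.Errorf("invalid event %s for GREEN state", event)
	}
	ctx.GreenTicks++
	if ctx.GreenTicks > 5 {
		ctx.GreenTicks = 0
		return yellow, nil
	}
	return green, nil
}

// yellow shows for 2 ticks and switches to red on the 3rd.
func yellow(ctx *LightContext, event string) (LightState, error) {
	if event != "TICK" {
		return yellow, fmt.Errorf("invalid event %s for YELLOW state", event)
	}
	ctx.YellowTicks++
	if ctx.YellowTicks > 2 {
		ctx.YellowTicks = 0
		return red, nil
	}
	return yellow, nil
}

// run starts at red, sends n TICK events, and returns the final context.
func run(n int) *LightContext {
	ctx := &LightContext{}
	state := LightState(red)
	for i := 0; i < n; i++ {
		next, err := state(ctx, "TICK")
		if err != nil {
			panic(err) // every state accepts TICK, so this should not happen
		}
		state = next
	}
	return ctx
}

func main() {
	fmt.Printf("%+v\n", run(12)) // &{RedTicks:0 GreenTicks:0 YellowTicks:2}
}
```

The trade-off is that the context type now knows about every state’s data; for a small, fixed machine like a traffic light that is usually a fair price for compile-time checking.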
Benefits of Function-State Machines
- Cleaner Code: Each state is isolated in its own function
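- Type Safety: The compiler checks that every state has the StateFunc signature and returns a StateFunc (although a nil return still type-checks)
- Easy to Extend: A new state is a new function; only the states that transition into it need to change
- Testable: Each state function can be called directly in a test; since Go function values can’t be compared with ==, tests check the returned error or the next state’s behavior rather than the function itself
- Composable: States are ordinary function values, so they can be wrapped by higher-order functions (for logging or metrics, say) or shared between machines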
- Context-Aware: State functions can access and modify shared context
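To make the "Composable" point concrete, here is a sketch of a higher-order wrapper that records every event a machine handles without editing any state function. `withTrace`, `drive`, `on` and `off` are hypothetical names; `StateFunc` has the same shape as in the TCP example.

```go
package main

import "fmt"

// StateFunc has the same shape as the TCP example's StateFunc.
type StateFunc func(event string) (StateFunc, error)

// withTrace wraps state s so every event it handles is appended to *trace.
// It also wraps the state that s returns, so tracing follows the machine
// through every transition. Wrappers do not pile up, because s itself is
// always an unwrapped state.
func withTrace(s StateFunc, trace *[]string) StateFunc {
	if s == nil {
		return nil // keep nil meaning "no state" for machines that use it
	}
	return func(event string) (StateFunc, error) {
		next, err := s(event)
		entry := event
		if err != nil {
			entry += " (rejected)"
		}
		*trace = append(*trace, entry)
		return withTrace(next, trace), err
	}
}

// off and on are two toy states that only understand "toggle".
func off(event string) (StateFunc, error) {
	if event == "toggle" {
		return on, nil
	}
	return off, fmt.Errorf("off: unexpected event %q", event)
}

func on(event string) (StateFunc, error) {
	if event == "toggle" {
		return off, nil
	}
	return on, fmt.Errorf("on: unexpected event %q", event)
}

// drive feeds events to a traced machine. Like ProcessEvent, it keeps the
// current state when a state returns an error.
func drive(start StateFunc, events ...string) []string {
	var trace []string
	state := withTrace(start, &trace)
	for _, e := range events {
		next, err := state(e)
		if err != nil {
			continue
		}
		state = next
	}
	return trace
}

func main() {
	fmt.Println(drive(off, "toggle", "push", "toggle"))
}
```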
When to Use This Pattern
Function-state machines are ideal when:
- You have complex state transitions with different behaviors per state
- States need to carry context or data
- You want to avoid deeply nested switch statements
- You need to test states independently
- You’re building protocol handlers, parsers, or game AI
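For parsers, the pattern usually gains a way to stop: a state returns nil and the driving loop ends. The lexer in Go’s text/template package is built on a similar `stateFn` type. In the sketch below, states pull input from a shared lexer instead of receiving events; `lexer`, `stateFn`, `lexStart`, `lexNumber`, `lexWord` and `lex` are hypothetical names, and the token format is made up for illustration.

```go
package main

import (
	"fmt"
	"unicode"
)

// lexer holds the input and the tokens found so far.
type lexer struct {
	input  []rune
	pos    int
	tokens []string
}

// stateFn is one lexer state; returning nil stops the machine.
type stateFn func(*lexer) stateFn

// lexStart skips spaces and picks the state for the next token.
func lexStart(l *lexer) stateFn {
	for l.pos < len(l.input) && unicode.IsSpace(l.input[l.pos]) {
		l.pos++
	}
	switch {
	case l.pos >= len(l.input):
		return nil // end of input: terminal state
	case unicode.IsDigit(l.input[l.pos]):
		return lexNumber
	default:
		return lexWord
	}
}

// lexNumber consumes a run of digits as one token.
func lexNumber(l *lexer) stateFn {
	start := l.pos
	for l.pos < len(l.input) && unicode.IsDigit(l.input[l.pos]) {
		l.pos++
	}
	l.tokens = append(l.tokens, "NUM:"+string(l.input[start:l.pos]))
	return lexStart
}

// lexWord consumes everything up to the next space or digit.
func lexWord(l *lexer) stateFn {
	start := l.pos
	for l.pos < len(l.input) && !unicode.IsSpace(l.input[l.pos]) && !unicode.IsDigit(l.input[l.pos]) {
		l.pos++
	}
	l.tokens = append(l.tokens, "WORD:"+string(l.input[start:l.pos]))
	return lexStart
}

// lex runs the machine until a state returns nil.
func lex(input string) []string {
	l := &lexer{input: []rune(input)}
	for state := stateFn(lexStart); state != nil; {
		state = state(l)
	}
	return l.tokens
}

func main() {
	fmt.Println(lex("abc 42 x7")) // [WORD:abc NUM:42 WORD:x NUM:7]
}
```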
Comparison: Switch vs Function-State
| Aspect | Switch-Based | Function-State |
|---|---|---|
| Readability | Nested, hard to follow | Clean, isolated |
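| Extensibility | Edit the central switch | Add a function, update the states that lead to it |
| Testing | Test through the whole switch | Call each state function directly |
| Type Safety | Invalid events caught at runtime | Invalid events caught at runtime; each state’s signature checked at compile time |
| State Data | Global or passed | Context parameter |
| Growth | One function grows with every transition | Each function holds only its own transitions |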
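Both approaches pass events as string literals, so a misspelled event such as `"SYN_AKC"` compiles and only fails at runtime. A sketch of one way to move that check to compile time, assuming you are free to change the event type, is a named `Event` type with constants. The `Event` type, its constants, and the trimmed `closed`/`listen` states are hypothetical.

```go
package main

import "fmt"

// Event replaces plain strings; a misspelled constant name fails to compile.
type Event int

const (
	PassiveOpen Event = iota
	ActiveOpen
	Syn
	SynAck
	Ack
	Fin
	Close
	Timeout
)

// String keeps log and error messages readable.
func (e Event) String() string {
	names := [...]string{"PASSIVE_OPEN", "ACTIVE_OPEN", "SYN", "SYN_ACK", "ACK", "FIN", "CLOSE", "TIMEOUT"}
	if e < 0 || int(e) >= len(names) {
		return fmt.Sprintf("Event(%d)", int(e))
	}
	return names[e]
}

type StateFunc func(Event) (StateFunc, error)

func closed(e Event) (StateFunc, error) {
	switch e {
	case PassiveOpen:
		return listen, nil
	default:
		return closed, fmt.Errorf("invalid event %v for CLOSED state", e)
	}
}

func listen(e Event) (StateFunc, error) {
	switch e {
	case Close:
		return closed, nil
	default:
		return listen, fmt.Errorf("invalid event %v for LISTEN state", e)
	}
}

func main() {
	state := StateFunc(closed)
	for _, e := range []Event{PassiveOpen, Syn, Close} {
		next, err := state(e)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Println("handled", e)
		state = next
	}
}
```

Note that an untyped constant such as `closed(3)` still compiles, so this guards against misspelled names rather than every invalid value; whether a state accepts a given event is still decided at runtime.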
Thank you
The function-state machine pattern is a powerful tool in Go’s functional programming toolkit. By treating states as functions that return their successors, you keep each state’s logic in one place, which makes state machines easier to read, test, and extend as they grow.