The Problem with Switch-Based State Machines

When building state machines in Go, many developers reach for switch statements. While this works for simple cases, it quickly becomes unwieldy as your state machine grows. Each state transition requires scanning through multiple case blocks, and adding new states means touching existing code in multiple places.

Let me show you a cleaner approach: using functions as first-class values to represent states. This pattern leverages Go’s ability to treat functions as data, creating elegant, extensible state machines.
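Before the full example, here is a minimal sketch of the core idea: a state is a function that does its work and returns the next state. The names `stateFn`, `start`, `working`, `done`, and `run` are hypothetical and chosen only for illustration. Returning `nil` to mean "finished" is an assumption of this sketch, not something the later examples rely on.

```go
package main

import "fmt"

// stateFn is a hypothetical name: a state does its work, records its name
// in the trace, and returns the next state (nil means the machine is done).
type stateFn func(trace *[]string) stateFn

func start(trace *[]string) stateFn {
    *trace = append(*trace, "start")
    return working
}

func working(trace *[]string) stateFn {
    *trace = append(*trace, "working")
    return done
}

func done(trace *[]string) stateFn {
    *trace = append(*trace, "done")
    return nil
}

// run drives the machine from the initial state until a state returns nil.
func run(initial stateFn) []string {
    var trace []string
    for state := initial; state != nil; {
        state = state(&trace)
    }
    return trace
}

func main() {
    fmt.Println(run(start)) // prints [start working done]
}
```

The driver loop never inspects which state it is in. It just calls whatever function it currently holds, which is what removes the outer switch.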

Traditional Switch-Based Approach

Here’s a typical implementation of a TCP connection state machine using switches:

type ConnectionState int

const (
    Closed ConnectionState = iota
    Listen
    SynSent
    SynReceived
    Established
    FinWait1
    FinWait2
    CloseWait
    Closing
    LastAck
    TimeWait
)

type Connection struct {
    state ConnectionState
}

func (c *Connection) HandleEvent(event string) error {
    switch c.state {
    case Closed:
        switch event {
        case "PASSIVE_OPEN":
            c.state = Listen
        case "ACTIVE_OPEN":
            c.state = SynSent
        default:
            return fmt.Errorf("invalid event %s for state Closed", event)
        }
    case Listen:
        switch event {
        case "SYN":
            c.state = SynReceived
        case "CLOSE":
            c.state = Closed
        default:
            return fmt.Errorf("invalid event %s for state Listen", event)
        }
    // ... many more case blocks
    }
    return nil
}

This approach has several issues:

  • Nested switches are hard to read
  • Adding a new state requires modifying the switch
  • State-specific logic is scattered across case blocks
  • No compile-time guarantee that all states handle all events

Function-State Machine Pattern

Here’s the same state machine using functions as states:

package main

import (
    "fmt"
    "log"
)

// StateFunc represents a state as a function
// It takes an event and returns the next state (or itself)
type StateFunc func(event string) (StateFunc, error)

// Connection represents a TCP connection
type Connection struct {
    currentState StateFunc
    name         string
}

func NewConnection(name string) *Connection {
    c := &Connection{name: name}
    c.currentState = c.closedState
    return c
}

// ProcessEvent handles an event and transitions to the next state
func (c *Connection) ProcessEvent(event string) error {
    nextState, err := c.currentState(event)
    if err != nil {
        return err
    }
    c.currentState = nextState
    return nil
}

// State functions - each state is a method that returns the next state

func (c *Connection) closedState(event string) (StateFunc, error) {
    log.Printf("[%s] CLOSED: received event %s", c.name, event)

    switch event {
    case "PASSIVE_OPEN":
        log.Printf("[%s] Transitioning: CLOSED -> LISTEN", c.name)
        return c.listenState, nil
    case "ACTIVE_OPEN":
        log.Printf("[%s] Transitioning: CLOSED -> SYN_SENT", c.name)
        return c.synSentState, nil
    default:
        return c.closedState, fmt.Errorf("invalid event %s for CLOSED state", event)
    }
}

func (c *Connection) listenState(event string) (StateFunc, error) {
    log.Printf("[%s] LISTEN: received event %s", c.name, event)

    switch event {
    case "SYN":
        log.Printf("[%s] Transitioning: LISTEN -> SYN_RECEIVED", c.name)
        return c.synReceivedState, nil
    case "CLOSE":
        log.Printf("[%s] Transitioning: LISTEN -> CLOSED", c.name)
        return c.closedState, nil
    default:
        return c.listenState, fmt.Errorf("invalid event %s for LISTEN state", event)
    }
}

func (c *Connection) synSentState(event string) (StateFunc, error) {
    log.Printf("[%s] SYN_SENT: received event %s", c.name, event)

    switch event {
    case "SYN_ACK":
        log.Printf("[%s] Transitioning: SYN_SENT -> ESTABLISHED", c.name)
        return c.establishedState, nil
    case "CLOSE":
        log.Printf("[%s] Transitioning: SYN_SENT -> CLOSED", c.name)
        return c.closedState, nil
    default:
        return c.synSentState, fmt.Errorf("invalid event %s for SYN_SENT state", event)
    }
}

func (c *Connection) synReceivedState(event string) (StateFunc, error) {
    log.Printf("[%s] SYN_RECEIVED: received event %s", c.name, event)

    switch event {
    case "ACK":
        log.Printf("[%s] Transitioning: SYN_RECEIVED -> ESTABLISHED", c.name)
        return c.establishedState, nil
    case "CLOSE":
        log.Printf("[%s] Transitioning: SYN_RECEIVED -> FIN_WAIT_1", c.name)
        return c.finWait1State, nil
    default:
        return c.synReceivedState, fmt.Errorf("invalid event %s for SYN_RECEIVED state", event)
    }
}

func (c *Connection) establishedState(event string) (StateFunc, error) {
    log.Printf("[%s] ESTABLISHED: received event %s", c.name, event)

    switch event {
    case "CLOSE":
        log.Printf("[%s] Transitioning: ESTABLISHED -> FIN_WAIT_1", c.name)
        return c.finWait1State, nil
    case "FIN":
        log.Printf("[%s] Transitioning: ESTABLISHED -> CLOSE_WAIT", c.name)
        return c.closeWaitState, nil
    default:
        return c.establishedState, fmt.Errorf("invalid event %s for ESTABLISHED state", event)
    }
}

func (c *Connection) finWait1State(event string) (StateFunc, error) {
    log.Printf("[%s] FIN_WAIT_1: received event %s", c.name, event)

    switch event {
    case "ACK":
        log.Printf("[%s] Transitioning: FIN_WAIT_1 -> FIN_WAIT_2", c.name)
        return c.finWait2State, nil
    case "FIN":
        log.Printf("[%s] Transitioning: FIN_WAIT_1 -> CLOSING", c.name)
        return c.closingState, nil
    default:
        return c.finWait1State, fmt.Errorf("invalid event %s for FIN_WAIT_1 state", event)
    }
}

func (c *Connection) finWait2State(event string) (StateFunc, error) {
    log.Printf("[%s] FIN_WAIT_2: received event %s", c.name, event)

    switch event {
    case "FIN":
        log.Printf("[%s] Transitioning: FIN_WAIT_2 -> TIME_WAIT", c.name)
        return c.timeWaitState, nil
    default:
        return c.finWait2State, fmt.Errorf("invalid event %s for FIN_WAIT_2 state", event)
    }
}

func (c *Connection) closeWaitState(event string) (StateFunc, error) {
    log.Printf("[%s] CLOSE_WAIT: received event %s", c.name, event)

    switch event {
    case "CLOSE":
        log.Printf("[%s] Transitioning: CLOSE_WAIT -> LAST_ACK", c.name)
        return c.lastAckState, nil
    default:
        return c.closeWaitState, fmt.Errorf("invalid event %s for CLOSE_WAIT state", event)
    }
}

func (c *Connection) closingState(event string) (StateFunc, error) {
    log.Printf("[%s] CLOSING: received event %s", c.name, event)

    switch event {
    case "ACK":
        log.Printf("[%s] Transitioning: CLOSING -> TIME_WAIT", c.name)
        return c.timeWaitState, nil
    default:
        return c.closingState, fmt.Errorf("invalid event %s for CLOSING state", event)
    }
}

func (c *Connection) lastAckState(event string) (StateFunc, error) {
    log.Printf("[%s] LAST_ACK: received event %s", c.name, event)

    switch event {
    case "ACK":
        log.Printf("[%s] Transitioning: LAST_ACK -> CLOSED", c.name)
        return c.closedState, nil
    default:
        return c.lastAckState, fmt.Errorf("invalid event %s for LAST_ACK state", event)
    }
}

func (c *Connection) timeWaitState(event string) (StateFunc, error) {
    log.Printf("[%s] TIME_WAIT: received event %s", c.name, event)

    switch event {
    case "TIMEOUT":
        log.Printf("[%s] Transitioning: TIME_WAIT -> CLOSED", c.name)
        return c.closedState, nil
    default:
        return c.timeWaitState, fmt.Errorf("invalid event %s for TIME_WAIT state", event)
    }
}

func main() {
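    // Println already appends a newline; print the blank line separately
    // so that go vet does not flag a redundant "\n".
    fmt.Println("=== TCP Connection State Machine ===")
    fmt.Println()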

    // Simulate a normal connection lifecycle
    fmt.Println("--- Scenario 1: Normal Connection ---")
    conn1 := NewConnection("Conn-1")

    events := []string{"ACTIVE_OPEN", "SYN_ACK", "CLOSE", "ACK", "FIN", "TIMEOUT"}
    for _, event := range events {
        if err := conn1.ProcessEvent(event); err != nil {
            log.Printf("[Conn-1] Error: %v\n", err)
        }
        fmt.Println()
    }

    // Simulate a passive connection
    fmt.Println("\n--- Scenario 2: Passive Connection ---")
    conn2 := NewConnection("Conn-2")

    events2 := []string{"PASSIVE_OPEN", "SYN", "ACK", "FIN", "CLOSE", "ACK"}
    for _, event := range events2 {
        if err := conn2.ProcessEvent(event); err != nil {
            log.Printf("[Conn-2] Error: %v\n", err)
        }
        fmt.Println()
    }

    // Test invalid transition
    fmt.Println("\n--- Scenario 3: Invalid Transition ---")
    conn3 := NewConnection("Conn-3")
    if err := conn3.ProcessEvent("INVALID_EVENT"); err != nil {
        log.Printf("[Conn-3] Expected error: %v\n", err)
    }
}

State transition diagram (Mermaid source):

graph LR
    CLOSED[CLOSED]:::lightBlue
    LISTEN[LISTEN]:::lightGreen
    SYN_SENT[SYN_SENT]:::lightYellow
    SYN_RECEIVED[SYN_RECEIVED]:::lightYellow
    ESTABLISHED[ESTABLISHED]:::lightGreen
    FIN_WAIT_1[FIN_WAIT_1]:::lightOrange
    FIN_WAIT_2[FIN_WAIT_2]:::lightOrange
    CLOSE_WAIT[CLOSE_WAIT]:::lightOrange
    CLOSING[CLOSING]:::lightOrange
    LAST_ACK[LAST_ACK]:::lightOrange
    TIME_WAIT[TIME_WAIT]:::lightPurple

    CLOSED -->|PASSIVE_OPEN| LISTEN
    CLOSED -->|ACTIVE_OPEN| SYN_SENT
    LISTEN -->|SYN| SYN_RECEIVED
    LISTEN -->|CLOSE| CLOSED
    SYN_SENT -->|SYN_ACK| ESTABLISHED
    SYN_SENT -->|CLOSE| CLOSED
    SYN_RECEIVED -->|ACK| ESTABLISHED
    SYN_RECEIVED -->|CLOSE| FIN_WAIT_1
    ESTABLISHED -->|CLOSE| FIN_WAIT_1
    ESTABLISHED -->|FIN| CLOSE_WAIT
    FIN_WAIT_1 -->|ACK| FIN_WAIT_2
    FIN_WAIT_1 -->|FIN| CLOSING
    FIN_WAIT_2 -->|FIN| TIME_WAIT
    CLOSE_WAIT -->|CLOSE| LAST_ACK
    CLOSING -->|ACK| TIME_WAIT
    LAST_ACK -->|ACK| CLOSED
    TIME_WAIT -->|TIMEOUT| CLOSED

    classDef lightBlue fill:#87CEEB,stroke:#4682B4,stroke-width:2px,color:#000
    classDef lightGreen fill:#90EE90,stroke:#228B22,stroke-width:2px,color:#000
    classDef lightYellow fill:#FFFFE0,stroke:#FFD700,stroke-width:2px,color:#000
    classDef lightOrange fill:#FFDAB9,stroke:#FF8C00,stroke-width:2px,color:#000
    classDef lightPurple fill:#DDA0DD,stroke:#9370DB,stroke-width:2px,color:#000
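One practical wrinkle with storing the state as a bare function: Go function values can only be compared to nil, so `Connection` cannot answer "which state am I in?" by comparing `currentState` against a known state. A possible workaround, sketched below, is to pair each handler with a label. The `State`, `Machine`, `NewMachine`, and `StateName` names are hypothetical, and only two of the TCP states are wired up to keep the sketch short.

```go
package main

import "fmt"

// State is a hypothetical type: a label plus the function that handles events.
type State struct {
    Name   string
    Handle func(event string) (*State, error)
}

// Machine tracks the current state so callers can ask for its name.
type Machine struct {
    current *State
}

// NewMachine wires up CLOSED and LISTEN as closures over each other.
func NewMachine() *Machine {
    closed := &State{Name: "CLOSED"}
    listen := &State{Name: "LISTEN"}

    closed.Handle = func(event string) (*State, error) {
        if event == "PASSIVE_OPEN" {
            return listen, nil
        }
        return closed, fmt.Errorf("invalid event %s for %s state", event, closed.Name)
    }
    listen.Handle = func(event string) (*State, error) {
        if event == "CLOSE" {
            return closed, nil
        }
        return listen, fmt.Errorf("invalid event %s for %s state", event, listen.Name)
    }
    return &Machine{current: closed}
}

// ProcessEvent leaves the current state unchanged when the handler fails.
func (m *Machine) ProcessEvent(event string) error {
    next, err := m.current.Handle(event)
    if err != nil {
        return err
    }
    m.current = next
    return nil
}

// StateName reports the label of the current state.
func (m *Machine) StateName() string { return m.current.Name }

func main() {
    m := NewMachine()
    fmt.Println(m.StateName()) // CLOSED
    if err := m.ProcessEvent("PASSIVE_OPEN"); err != nil {
        fmt.Println("error:", err)
    }
    fmt.Println(m.StateName()) // LISTEN
}
```

The states are built inside a constructor rather than as package-level variables because package-level variables whose initializers refer to each other are rejected by the compiler as an initialization cycle.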

Advanced Pattern: State Machine with Context

For more complex scenarios, you can pass context to state functions:

package main

import (
    "fmt"
    "time"
)

// Context carries data through state transitions
type Context struct {
    Data       map[string]interface{}
    StartTime  time.Time
    EventCount int
}

// StateFuncWithContext takes context and event, returns next state
type StateFuncWithContext func(*Context, string) (StateFuncWithContext, error)

// TrafficLight implements a traffic light state machine
type TrafficLight struct {
    currentState StateFuncWithContext
    context      *Context
    name         string
}

func NewTrafficLight(name string) *TrafficLight {
    tl := &TrafficLight{
        name: name,
        context: &Context{
            Data:      make(map[string]interface{}),
            StartTime: time.Now(),
        },
    }
    tl.currentState = tl.redState
    return tl
}

func (tl *TrafficLight) Tick() error {
    tl.context.EventCount++
    nextState, err := tl.currentState(tl.context, "TICK")
    if err != nil {
        return err
    }
    tl.currentState = nextState
    return nil
}

func (tl *TrafficLight) Emergency() error {
    nextState, err := tl.currentState(tl.context, "EMERGENCY")
    if err != nil {
        return err
    }
    tl.currentState = nextState
    return nil
}

func (tl *TrafficLight) redState(ctx *Context, event string) (StateFuncWithContext, error) {
    switch event {
    case "TICK":
        duration := ctx.Data["redDuration"]
        if duration == nil {
            ctx.Data["redDuration"] = 1
        } else {
            ctx.Data["redDuration"] = duration.(int) + 1
        }

        if duration != nil && duration.(int) >= 3 {
            fmt.Printf("[%s] RED -> GREEN (waited %d ticks)\n", tl.name, duration.(int))
            ctx.Data["redDuration"] = 0
            return tl.greenState, nil
        }
        fmt.Printf("[%s] RED (tick %d)\n", tl.name, ctx.Data["redDuration"].(int))
        return tl.redState, nil

    case "EMERGENCY":
        fmt.Printf("[%s] RED -> FLASHING (emergency)\n", tl.name)
        return tl.flashingState, nil

    default:
        return tl.redState, fmt.Errorf("invalid event %s for RED state", event)
    }
}

func (tl *TrafficLight) greenState(ctx *Context, event string) (StateFuncWithContext, error) {
    switch event {
    case "TICK":
        duration := ctx.Data["greenDuration"]
        if duration == nil {
            ctx.Data["greenDuration"] = 1
        } else {
            ctx.Data["greenDuration"] = duration.(int) + 1
        }

        if duration != nil && duration.(int) >= 5 {
            fmt.Printf("[%s] GREEN -> YELLOW (waited %d ticks)\n", tl.name, duration.(int))
            ctx.Data["greenDuration"] = 0
            return tl.yellowState, nil
        }
        fmt.Printf("[%s] GREEN (tick %d)\n", tl.name, ctx.Data["greenDuration"].(int))
        return tl.greenState, nil

    case "EMERGENCY":
        fmt.Printf("[%s] GREEN -> FLASHING (emergency)\n", tl.name)
        return tl.flashingState, nil

    default:
        return tl.greenState, fmt.Errorf("invalid event %s for GREEN state", event)
    }
}

func (tl *TrafficLight) yellowState(ctx *Context, event string) (StateFuncWithContext, error) {
    switch event {
    case "TICK":
        duration := ctx.Data["yellowDuration"]
        if duration == nil {
            ctx.Data["yellowDuration"] = 1
        } else {
            ctx.Data["yellowDuration"] = duration.(int) + 1
        }

        if duration != nil && duration.(int) >= 2 {
            fmt.Printf("[%s] YELLOW -> RED (waited %d ticks)\n", tl.name, duration.(int))
            ctx.Data["yellowDuration"] = 0
            return tl.redState, nil
        }
        fmt.Printf("[%s] YELLOW (tick %d)\n", tl.name, ctx.Data["yellowDuration"].(int))
        return tl.yellowState, nil

    case "EMERGENCY":
        fmt.Printf("[%s] YELLOW -> FLASHING (emergency)\n", tl.name)
        return tl.flashingState, nil

    default:
        return tl.yellowState, fmt.Errorf("invalid event %s for YELLOW state", event)
    }
}

func (tl *TrafficLight) flashingState(ctx *Context, event string) (StateFuncWithContext, error) {
    switch event {
    case "TICK":
        fmt.Printf("[%s] FLASHING (emergency mode)\n", tl.name)
        return tl.flashingState, nil

    case "EMERGENCY":
        fmt.Printf("[%s] FLASHING -> RED (emergency cleared)\n", tl.name)
        return tl.redState, nil

    default:
        return tl.flashingState, fmt.Errorf("invalid event %s for FLASHING state", event)
    }
}

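// A package main program needs a main function to build, so the demo
// lives in main rather than in a separate Example function.
func main() {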
    light := NewTrafficLight("Intersection-A")

    // Normal operation
    fmt.Println("=== Normal Traffic Light Cycle ===")
    for i := 0; i < 12; i++ {
        light.Tick()
    }

    // Emergency
    fmt.Println("\n=== Emergency Scenario ===")
    light.Emergency()
    light.Tick()
    light.Tick()
    light.Emergency() // Clear emergency
    light.Tick()
}
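The `map[string]interface{}` in `Context` needs a nil check and a type assertion on every read. If the data a machine carries is known up front, a typed struct may be simpler. The sketch below reworks only the tick counting, under these assumptions: the names `lightContext`, `lightState`, `tick`, and `simulate` are hypothetical, the emergency event is left out, and the limits (3, 5, 2) are taken from the traffic light code above.

```go
package main

import "fmt"

// lightContext is a hypothetical typed replacement for the Data map.
type lightContext struct {
    color        string // state that handled the most recent tick
    ticksInState int    // ticks spent in the current state
}

// lightState handles one tick and returns the next state.
type lightState func(ctx *lightContext) lightState

// tick records the color, counts the tick, and reports whether the state
// has now exceeded its limit. The counter resets when it has.
func (ctx *lightContext) tick(color string, limit int) bool {
    ctx.color = color
    ctx.ticksInState++
    if ctx.ticksInState > limit {
        ctx.ticksInState = 0
        return true
    }
    return false
}

func red(ctx *lightContext) lightState {
    if ctx.tick("RED", 3) {
        return green
    }
    return red
}

func green(ctx *lightContext) lightState {
    if ctx.tick("GREEN", 5) {
        return yellow
    }
    return green
}

func yellow(ctx *lightContext) lightState {
    if ctx.tick("YELLOW", 2) {
        return red
    }
    return yellow
}

// simulate runs n ticks starting from red and returns the color that
// handled each tick.
func simulate(n int) []string {
    ctx := &lightContext{}
    var state lightState = red
    colors := make([]string, 0, n)
    for i := 0; i < n; i++ {
        state = state(ctx)
        colors = append(colors, ctx.color)
    }
    return colors
}

func main() {
    fmt.Println(simulate(12))
}
```

As in the map-based version, a state stays put for its limit and hands over on the following tick, so twelve ticks are handled by RED four times, GREEN six times, and YELLOW twice.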

Benefits of Function-State Machines

  1. Cleaner Code: Each state is isolated in its own function
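  2. Type Safety: The compiler checks that every state returns a StateFunc (event names are still strings, so those are checked at runtime)
  3. Easy to Extend: Adding a state means adding a function; only the states that transition into it need to change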
  4. Testable: Each state function can be tested independently
  5. Composable: State functions can be combined and reused
  6. Context-Aware: State functions can access and modify shared context
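To make the "Testable" point concrete, here is a rough sketch of a table-driven check that exercises one state function in isolation. It assumes cut-down, receiver-free versions of the CLOSED and LISTEN states. The `transitionCase` type and `checkState` helper are hypothetical. In a real project the same table would more likely sit in a `_test.go` file and use the standard `testing` package.

```go
package main

import "fmt"

// StateFunc matches the type used in the TCP example.
type StateFunc func(event string) (StateFunc, error)

// closedState and listenState are simplified stand-ins for the article's
// methods, reduced to one valid event each.
func closedState(event string) (StateFunc, error) {
    if event == "PASSIVE_OPEN" {
        return listenState, nil
    }
    return closedState, fmt.Errorf("invalid event %s for CLOSED state", event)
}

func listenState(event string) (StateFunc, error) {
    if event == "CLOSE" {
        return closedState, nil
    }
    return listenState, fmt.Errorf("invalid event %s for LISTEN state", event)
}

// transitionCase is one row of the table (hypothetical helper type).
type transitionCase struct {
    event   string
    wantErr bool
}

// checkState feeds each event to a single state function and verifies that
// an error comes back exactly when expected and that the next state is non-nil.
func checkState(state StateFunc, cases []transitionCase) error {
    for _, tc := range cases {
        next, err := state(tc.event)
        if (err != nil) != tc.wantErr {
            return fmt.Errorf("event %s: got err %v, wantErr %t", tc.event, err, tc.wantErr)
        }
        if next == nil {
            return fmt.Errorf("event %s: next state is nil", tc.event)
        }
    }
    return nil
}

func main() {
    err := checkState(listenState, []transitionCase{
        {event: "CLOSE", wantErr: false},
        {event: "ACK", wantErr: true},
    })
    fmt.Println("listenState check:", err) // listenState check: <nil>
}
```

Because function values cannot be compared with `==`, the check looks only at the error and at whether a next state was returned. To confirm which state came back, one option is to feed the returned function an event that only the expected state accepts.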

When to Use This Pattern

Function-state machines are ideal when:

  • You have complex state transitions with different behaviors per state
  • States need to carry context or data
  • You want to avoid deeply nested switch statements
  • You need to test states independently
  • You’re building protocol handlers, parsers, or game AI

Comparison: Switch vs Function-State

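Aspect         Switch-Based                          Function-State
-------------  ------------------------------------  ----------------------------------------------------
Readability    Nested, hard to follow                Clean, isolated
Extensibility  Must modify the switch                Add a new function
Testing        Test the entire switch                Test each state on its own
Type Safety    States are ints, events are strings   Next state is type-checked, events are still strings
State Data     Global or passed                      Context parameter
Complexity     One handler grows with every state    One small function per state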
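A note on the type-safety comparison: both versions in this article pass events as strings, so a mistyped event name compiles and only fails at runtime. One way to move some of that checking to compile time, sketched here as an assumption rather than as part of the original design, is a defined `Event` type with named constants. The `Event` type and its constants are hypothetical names.

```go
package main

import "fmt"

// Event is a hypothetical defined type replacing the string event names.
type Event int

const (
    PassiveOpen Event = iota
    ActiveOpen
    Close
)

// String makes events readable in logs and error messages.
func (e Event) String() string {
    switch e {
    case PassiveOpen:
        return "PASSIVE_OPEN"
    case ActiveOpen:
        return "ACTIVE_OPEN"
    case Close:
        return "CLOSE"
    }
    return fmt.Sprintf("Event(%d)", int(e))
}

// StateFunc now takes a typed event instead of a string.
type StateFunc func(event Event) (StateFunc, error)

func closedState(event Event) (StateFunc, error) {
    switch event {
    case PassiveOpen:
        return listenState, nil
    default:
        return closedState, fmt.Errorf("invalid event %s for CLOSED state", event)
    }
}

func listenState(event Event) (StateFunc, error) {
    switch event {
    case Close:
        return closedState, nil
    default:
        return listenState, fmt.Errorf("invalid event %s for LISTEN state", event)
    }
}

func main() {
    _, err := closedState(Close)
    fmt.Println(err) // invalid event CLOSE for CLOSED state
}
```

With this change a misspelled constant such as `PasiveOpen` is a compile error, whereas the string "PASIVE_OPEN" would compile. The compiler still does not verify that every state handles every event, so the `default` branches remain necessary.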

Thank you

The function-state machine pattern is a powerful tool in Go’s functional programming toolkit. By treating states as functions that return their successors, you create elegant, maintainable state machines that scale beautifully with complexity.
