Go Concurrency Patterns Series: ← Context Pattern | Series Overview | Rate Limiter →


What is the Circuit Breaker Pattern?

The Circuit Breaker pattern prevents cascading failures in distributed systems by monitoring for failures and temporarily stopping calls to failing services. Like an electrical circuit breaker, it “trips” when failures exceed a threshold, giving the failing service time to recover.

States:

  • Closed: Normal operation, requests pass through
  • Open: Failing fast, requests are rejected immediately
  • Half-Open: A limited number of trial requests are let through to test whether the service has recovered

Real-World Use Cases

  • Microservices: Prevent cascade failures between services
  • Database Connections: Handle database outages gracefully
  • External APIs: Deal with third-party service failures
  • Payment Processing: Handle payment gateway issues
  • File Systems: Manage disk I/O failures
  • Network Operations: Handle network partitions

Basic Circuit Breaker Implementation

package main

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"
)

// State represents the circuit breaker state.
type State int

const (
    // StateClosed lets every request through and counts consecutive failures.
    StateClosed State = iota
    // StateOpen rejects requests immediately until the reset timeout passes.
    StateOpen
    // StateHalfOpen admits a limited number of trial requests to test recovery.
    StateHalfOpen
)

// String returns the upper-case name of s, or "UNKNOWN" for an undefined value.
func (s State) String() string {
    switch s {
    case StateClosed:
        return "CLOSED"
    case StateOpen:
        return "OPEN"
    case StateHalfOpen:
        return "HALF_OPEN"
    default:
        return "UNKNOWN"
    }
}

// ErrCircuitOpen is returned by Execute when the breaker rejects a call, either
// because it is open or because the half-open trial quota is already in use.
// Callers can detect it with errors.Is.
var ErrCircuitOpen = errors.New("circuit breaker is open")

// errPanicked is recorded for a protected call that panicked instead of
// returning; only its non-nil-ness matters to the breaker.
var errPanicked = errors.New("circuit breaker: protected call panicked")

// CircuitBreaker implements the circuit breaker pattern. It is safe for
// concurrent use: all mutable state is guarded by mu.
type CircuitBreaker struct {
    mu              sync.RWMutex
    state           State
    failureCount    int // consecutive failures while closed; reset on entering half-open
    successCount    int // successful trial calls in the current half-open period
    halfOpenCalls   int // trial calls admitted in the current half-open period
    lastFailureTime time.Time
    nextAttemptTime time.Time // earliest time an open breaker may admit a trial call

    // Configuration
    maxFailures      int
    resetTimeout     time.Duration
    halfOpenMaxCalls int
}

// Config holds circuit breaker configuration.
type Config struct {
    MaxFailures      int           // Number of consecutive failures before opening
    ResetTimeout     time.Duration // Time to wait before trying half-open
    HalfOpenMaxCalls int           // Trial calls admitted (and successes needed to close) in half-open state
}

// NewCircuitBreaker creates a new circuit breaker in the closed state.
func NewCircuitBreaker(config Config) *CircuitBreaker {
    return &CircuitBreaker{
        state:            StateClosed,
        maxFailures:      config.MaxFailures,
        resetTimeout:     config.ResetTimeout,
        halfOpenMaxCalls: config.HalfOpenMaxCalls,
    }
}

// Execute runs fn with circuit breaker protection. If the breaker rejects the
// call, fn is not invoked and ErrCircuitOpen is returned. Otherwise fn's error
// is recorded (nil counts as a success) and returned unchanged.
//
// If fn panics, the call is recorded as a failure and the panic continues to
// propagate, so an admitted half-open trial slot is never leaked.
func (cb *CircuitBreaker) Execute(fn func() error) error {
    if !cb.allowRequest() {
        return ErrCircuitOpen
    }

    finished := false
    defer func() {
        // finished is still false only if fn panicked (or called
        // runtime.Goexit); nothing recovers, so the panic keeps unwinding.
        if !finished {
            cb.recordResult(errPanicked)
        }
    }()

    err := fn()
    finished = true
    cb.recordResult(err)
    return err
}

// ExecuteWithContext is the context-aware form of Execute: fn receives ctx so
// it can honour cancellation. If ctx is already done, fn is not called,
// nothing is recorded, and ctx.Err() is returned. No extra timeout is applied.
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func(context.Context) error) error {
    if err := ctx.Err(); err != nil {
        return err
    }
    return cb.Execute(func() error { return fn(ctx) })
}

// allowRequest reports whether a call may proceed, moving an open breaker to
// half-open once its reset timeout has passed. In half-open it admits at most
// halfOpenMaxCalls trial calls, counting them when they are admitted rather
// than when they finish, so concurrent callers cannot exceed the quota.
func (cb *CircuitBreaker) allowRequest() bool {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    switch cb.state {
    case StateClosed:
        return true
    case StateOpen:
        if !time.Now().After(cb.nextAttemptTime) {
            return false
        }
        cb.state = StateHalfOpen
        cb.successCount = 0
        cb.failureCount = 0
        cb.halfOpenCalls = 1 // this caller is the first trial
        fmt.Printf("Circuit breaker transitioning to HALF_OPEN\n")
        return true
    case StateHalfOpen:
        if cb.halfOpenCalls >= cb.halfOpenMaxCalls {
            return false
        }
        cb.halfOpenCalls++
        return true
    default:
        return false
    }
}

// recordResult records the result of a function call.
func (cb *CircuitBreaker) recordResult(err error) {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    if err != nil {
        cb.onFailure()
    } else {
        cb.onSuccess()
    }
}

// onFailure handles a failure. A closed breaker opens after maxFailures
// consecutive failures; a half-open breaker reopens on any failure.
// Callers must hold cb.mu.
func (cb *CircuitBreaker) onFailure() {
    cb.failureCount++
    cb.lastFailureTime = time.Now()

    switch cb.state {
    case StateClosed:
        if cb.failureCount >= cb.maxFailures {
            cb.state = StateOpen
            cb.nextAttemptTime = time.Now().Add(cb.resetTimeout)
            fmt.Printf("Circuit breaker OPENED after %d failures\n", cb.failureCount)
        }
    case StateHalfOpen:
        cb.state = StateOpen
        cb.nextAttemptTime = time.Now().Add(cb.resetTimeout)
        fmt.Printf("Circuit breaker returned to OPEN from HALF_OPEN\n")
    }
}

// onSuccess handles a success. It clears the failure streak while closed,
// and closes a half-open breaker once halfOpenMaxCalls trials have succeeded.
// Callers must hold cb.mu.
func (cb *CircuitBreaker) onSuccess() {
    switch cb.state {
    case StateClosed:
        cb.failureCount = 0
    case StateHalfOpen:
        cb.successCount++
        if cb.successCount >= cb.halfOpenMaxCalls {
            cb.state = StateClosed
            // Start the closed period with clean counters so GetStats does
            // not report half-open leftovers.
            cb.failureCount = 0
            cb.successCount = 0
            cb.halfOpenCalls = 0
            fmt.Printf("Circuit breaker CLOSED after successful recovery\n")
        }
    }
}

// GetState returns the current state.
func (cb *CircuitBreaker) GetState() State {
    cb.mu.RLock()
    defer cb.mu.RUnlock()
    return cb.state
}

// GetStats returns the current state, failure count and success count.
func (cb *CircuitBreaker) GetStats() (State, int, int) {
    cb.mu.RLock()
    defer cb.mu.RUnlock()
    return cb.state, cb.failureCount, cb.successCount
}

// simulateService builds a fake service call: the returned function sleeps
// for delay, then returns a "service failure" error when shouldFail is set
// and nil otherwise.
func simulateService(shouldFail bool, delay time.Duration) func() error {
    call := func() error {
        time.Sleep(delay)
        if !shouldFail {
            return nil
        }
        return errors.New("service failure")
    }
    return call
}

func main() {
    config := Config{
        MaxFailures:      3,
        ResetTimeout:     2 * time.Second,
        HalfOpenMaxCalls: 2,
    }
    cb := NewCircuitBreaker(config)

    // printOutcome reports a call's result followed by the breaker counters.
    printOutcome := func(err error) {
        state, failures, successes := cb.GetStats()
        if err == nil {
            fmt.Printf("Result: SUCCESS\n")
        } else {
            fmt.Printf("Result: ERROR - %v\n", err)
        }
        fmt.Printf("State: %s, Failures: %d, Successes: %d\n",
            state, failures, successes)
    }

    // Phase 1: two successes, three failures that trip the breaker, then two
    // calls the open breaker is expected to reject.
    type step struct {
        name       string
        shouldFail bool
        delay      time.Duration
    }
    const callDelay = 100 * time.Millisecond
    steps := []step{
        {"Success 1", false, callDelay},
        {"Success 2", false, callDelay},
        {"Failure 1", true, callDelay},
        {"Failure 2", true, callDelay},
        {"Failure 3", true, callDelay},  // Should open circuit
        {"Blocked 1", false, callDelay}, // Should be blocked
        {"Blocked 2", false, callDelay}, // Should be blocked
    }
    for idx := range steps {
        s := steps[idx]
        fmt.Printf("\n--- Test %d: %s ---\n", idx+1, s.name)
        printOutcome(cb.Execute(simulateService(s.shouldFail, s.delay)))
        time.Sleep(100 * time.Millisecond)
    }

    // Phase 2: let the reset timeout expire, then probe recovery.
    fmt.Printf("\n--- Waiting for reset timeout (%v) ---\n", config.ResetTimeout)
    time.Sleep(config.ResetTimeout + 100*time.Millisecond)

    recovery := []struct {
        name       string
        shouldFail bool
    }{
        {"Recovery 1", false}, // Should succeed and move to half-open
        {"Recovery 2", false}, // Should succeed and close circuit
        {"Success after recovery", false},
    }
    for idx, r := range recovery {
        fmt.Printf("\n--- Recovery Test %d: %s ---\n", idx+1, r.name)
        printOutcome(cb.Execute(simulateService(r.shouldFail, callDelay)))
    }
}

Advanced Circuit Breaker with Metrics

package main

import (
    "context"
    "fmt"
    "math/rand"
    "sync"
    "sync/atomic"
    "time"
)

// Metrics is a set of circuit breaker counters. Every field is an int64
// because the breaker updates and reads them with sync/atomic functions.
type Metrics struct {
    // Call outcomes.
    totalRequests   int64
    successfulCalls int64
    failedCalls     int64
    rejectedCalls   int64

    // Notable events.
    timeouts     int64
    stateChanges int64
}

// AdvancedCircuitBreaker is a circuit breaker with per-call timeouts, atomic
// metrics and an optional state-change callback. Breaker state is guarded by
// mu; the counters in metrics are updated with sync/atomic.
type AdvancedCircuitBreaker struct {
    mu              sync.RWMutex
    state           State
    failureCount    int
    successCount    int
    halfOpenCalls   int // trial calls admitted in the current half-open period
    lastFailureTime time.Time
    nextAttemptTime time.Time
    stateChangeTime time.Time

    // Configuration
    maxFailures      int
    resetTimeout     time.Duration
    halfOpenMaxCalls int
    callTimeout      time.Duration

    // Metrics
    metrics *Metrics

    // Monitoring
    onStateChange func(from, to State)
}

// AdvancedConfig holds advanced circuit breaker configuration.
type AdvancedConfig struct {
    MaxFailures      int                  // consecutive failures that open a closed breaker
    ResetTimeout     time.Duration        // how long to stay open before admitting a trial call
    HalfOpenMaxCalls int                  // trial calls admitted, and successes needed to close, in half-open
    CallTimeout      time.Duration        // per-call timeout; values <= 0 disable it
    OnStateChange    func(from, to State) // optional; invoked on its own goroutine
}

// NewAdvancedCircuitBreaker creates a new advanced circuit breaker in the
// closed state.
func NewAdvancedCircuitBreaker(config AdvancedConfig) *AdvancedCircuitBreaker {
    return &AdvancedCircuitBreaker{
        state:            StateClosed,
        maxFailures:      config.MaxFailures,
        resetTimeout:     config.ResetTimeout,
        halfOpenMaxCalls: config.HalfOpenMaxCalls,
        callTimeout:      config.CallTimeout,
        metrics:          &Metrics{},
        onStateChange:    config.OnStateChange,
        stateChangeTime:  time.Now(),
    }
}

// ExecuteWithContext runs fn through the breaker. A rejected call returns an
// error naming the breaker state without invoking fn. An admitted call gets
// ctx, bounded by CallTimeout when one is configured. If that context is done
// before fn returns, the call is recorded as a failure and ctx.Err() is
// returned. fn keeps running in the background until it returns.
func (acb *AdvancedCircuitBreaker) ExecuteWithContext(ctx context.Context, fn func(context.Context) error) error {
    atomic.AddInt64(&acb.metrics.totalRequests, 1)

    allowed, state := acb.allowRequest()
    if !allowed {
        atomic.AddInt64(&acb.metrics.rejectedCalls, 1)
        // state was captured under the same lock as the decision, so the
        // message names the state that actually caused the rejection.
        return fmt.Errorf("circuit breaker is %s", state)
    }

    // Create context with timeout if specified
    if acb.callTimeout > 0 {
        var cancel context.CancelFunc
        ctx, cancel = context.WithTimeout(ctx, acb.callTimeout)
        defer cancel()
    }

    // fn runs on its own goroutine so the breaker can stop waiting when ctx is
    // done; the buffer lets that goroutine finish even if nobody receives.
    done := make(chan error, 1)
    go func() {
        done <- fn(ctx)
    }()

    select {
    case err := <-done:
        acb.recordResult(err)
        return err
    case <-ctx.Done():
        err := ctx.Err()
        // Only deadline expiry is a timeout; a caller's cancellation is still
        // recorded as a failure, but not counted in the timeouts metric.
        if err == context.DeadlineExceeded {
            atomic.AddInt64(&acb.metrics.timeouts, 1)
        }
        acb.recordResult(err)
        return err
    }
}

// allowRequest reports whether a call may proceed, together with the state
// observed while deciding. An open breaker moves to half-open once its reset
// timeout has passed. In half-open at most halfOpenMaxCalls trial calls are
// admitted, counted at admission rather than completion, so concurrent
// callers cannot exceed the quota while trials are still in flight.
func (acb *AdvancedCircuitBreaker) allowRequest() (bool, State) {
    acb.mu.Lock()
    defer acb.mu.Unlock()

    switch acb.state {
    case StateClosed:
        return true, acb.state
    case StateOpen:
        if !time.Now().After(acb.nextAttemptTime) {
            return false, acb.state
        }
        acb.changeState(StateHalfOpen)
        acb.successCount = 0
        acb.failureCount = 0
        acb.halfOpenCalls = 1 // this caller is the first trial
        return true, acb.state
    case StateHalfOpen:
        if acb.halfOpenCalls >= acb.halfOpenMaxCalls {
            return false, acb.state
        }
        acb.halfOpenCalls++
        return true, acb.state
    default:
        return false, acb.state
    }
}

// recordResult bumps the success or failure metric for a finished call and
// feeds the outcome into the state machine, all while holding acb.mu.
func (acb *AdvancedCircuitBreaker) recordResult(err error) {
    acb.mu.Lock()
    defer acb.mu.Unlock()

    if err == nil {
        atomic.AddInt64(&acb.metrics.successfulCalls, 1)
        acb.onSuccess()
        return
    }
    atomic.AddInt64(&acb.metrics.failedCalls, 1)
    acb.onFailure()
}

// changeState moves the breaker to newState, stamping stateChangeTime and
// bumping the stateChanges metric; it does nothing if the state is unchanged.
// The OnStateChange callback, when set, is started on its own goroutine, so
// callbacks for quick successive transitions may run out of order.
// Callers hold acb.mu.
func (acb *AdvancedCircuitBreaker) changeState(newState State) {
    if acb.state == newState {
        return
    }

    previous := acb.state
    acb.state = newState
    acb.stateChangeTime = time.Now()
    atomic.AddInt64(&acb.metrics.stateChanges, 1)

    if notify := acb.onStateChange; notify != nil {
        go notify(previous, newState)
    }
}

// onFailure counts a failed call. A closed breaker opens once failureCount
// reaches maxFailures; a half-open breaker reopens on any failure. Either way
// the next trial is scheduled resetTimeout from now. Callers hold acb.mu.
func (acb *AdvancedCircuitBreaker) onFailure() {
    acb.failureCount++
    acb.lastFailureTime = time.Now()

    tripped := acb.state == StateHalfOpen ||
        (acb.state == StateClosed && acb.failureCount >= acb.maxFailures)
    if !tripped {
        return
    }
    acb.changeState(StateOpen)
    acb.nextAttemptTime = time.Now().Add(acb.resetTimeout)
}

// onSuccess counts a successful call. While closed it clears the failure
// streak; while half-open it closes the breaker once halfOpenMaxCalls trials
// have succeeded. Other states are left alone. Callers hold acb.mu.
func (acb *AdvancedCircuitBreaker) onSuccess() {
    if acb.state == StateClosed {
        acb.failureCount = 0
        return
    }
    if acb.state != StateHalfOpen {
        return
    }

    acb.successCount++
    if acb.successCount < acb.halfOpenMaxCalls {
        return
    }
    acb.changeState(StateClosed)
    acb.failureCount = 0
}

// GetMetrics returns a copy of the breaker's counters. Each counter is loaded
// atomically, but the copy is not one atomic snapshot: under concurrent
// traffic the fields may reflect slightly different moments.
func (acb *AdvancedCircuitBreaker) GetMetrics() Metrics {
    src := acb.metrics
    var snapshot Metrics
    snapshot.totalRequests = atomic.LoadInt64(&src.totalRequests)
    snapshot.successfulCalls = atomic.LoadInt64(&src.successfulCalls)
    snapshot.failedCalls = atomic.LoadInt64(&src.failedCalls)
    snapshot.rejectedCalls = atomic.LoadInt64(&src.rejectedCalls)
    snapshot.timeouts = atomic.LoadInt64(&src.timeouts)
    snapshot.stateChanges = atomic.LoadInt64(&src.stateChanges)
    return snapshot
}

// GetState reads the breaker's current state under the read lock.
func (acb *AdvancedCircuitBreaker) GetState() State {
    acb.mu.RLock()
    current := acb.state
    acb.mu.RUnlock()
    return current
}

// HealthCheck returns a map describing the breaker: its state fields (read
// under the read lock), all metrics counters, and success_rate, which is
// successful calls as a percentage of total requests (rejections included),
// formatted with two decimals.
func (acb *AdvancedCircuitBreaker) HealthCheck() map[string]interface{} {
    acb.mu.RLock()
    defer acb.mu.RUnlock()

    m := acb.GetMetrics()

    successRate := 0.0
    if m.totalRequests > 0 {
        successRate = float64(m.successfulCalls) / float64(m.totalRequests) * 100
    }

    report := make(map[string]interface{}, 13)
    report["state"] = acb.state.String()
    report["failure_count"] = acb.failureCount
    report["success_count"] = acb.successCount
    report["last_failure_time"] = acb.lastFailureTime
    report["state_change_time"] = acb.stateChangeTime
    report["next_attempt_time"] = acb.nextAttemptTime
    report["total_requests"] = m.totalRequests
    report["successful_calls"] = m.successfulCalls
    report["failed_calls"] = m.failedCalls
    report["rejected_calls"] = m.rejectedCalls
    report["timeouts"] = m.timeouts
    report["state_changes"] = m.stateChanges
    report["success_rate"] = fmt.Sprintf("%.2f%%", successRate)
    return report
}

// ExternalService simulates a remote dependency with a fixed latency and a
// random failure probability.
type ExternalService struct {
    failureRate float64       // probability that a call fails (<= 0 never, >= 1 always)
    latency     time.Duration // simulated response time
}

// Call waits for the configured latency, or returns ctx.Err() if ctx is done
// first, and then fails with probability failureRate.
func (es *ExternalService) Call(ctx context.Context, data string) error {
    // NewTimer + Stop rather than time.After so the timer is released
    // promptly when ctx wins the race.
    timer := time.NewTimer(es.latency)
    defer timer.Stop()

    select {
    case <-timer.C:
    case <-ctx.Done():
        return ctx.Err()
    }

    // Use a real PRNG instead of the wall clock. UnixNano()%100 is always 0
    // on clocks with 100ns-or-coarser resolution (making every call fail),
    // and it is correlated across calls that finish at the same instant.
    if rand.Float64() < es.failureRate {
        return fmt.Errorf("service failure for data: %s", data)
    }

    return nil
}

func main() {
    // Create service that fails 30% of the time
    service := &ExternalService{
        failureRate: 0.3,
        latency:     100 * time.Millisecond,
    }

    config := AdvancedConfig{
        MaxFailures:      3,
        ResetTimeout:     2 * time.Second,
        HalfOpenMaxCalls: 2,
        CallTimeout:      500 * time.Millisecond,
        OnStateChange: func(from, to State) {
            fmt.Printf("🔄 Circuit breaker state changed: %s -> %s\n", from, to)
        },
    }

    cb := NewAdvancedCircuitBreaker(config)

    // Monitor circuit breaker health once per second until the load test is
    // over. Closing stopMonitor asks the monitor to exit; monitorDone is closed
    // once it has, so its output cannot interleave with the final report.
    stopMonitor := make(chan struct{})
    monitorDone := make(chan struct{})
    go func() {
        defer close(monitorDone)
        ticker := time.NewTicker(1 * time.Second)
        defer ticker.Stop()

        for {
            select {
            case <-stopMonitor:
                return
            case <-ticker.C:
                health := cb.HealthCheck()
                fmt.Printf("📊 Health: State=%s, Success Rate=%s, Total=%d, Failed=%d, Rejected=%d\n",
                    health["state"], health["success_rate"],
                    health["total_requests"], health["failed_calls"], health["rejected_calls"])
            }
        }
    }()

    // Simulate load
    var wg sync.WaitGroup
    for i := 0; i < 50; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
            defer cancel()

            err := cb.ExecuteWithContext(ctx, func(ctx context.Context) error {
                return service.Call(ctx, fmt.Sprintf("request-%d", id))
            })

            if err != nil {
                fmt.Printf("❌ Request %d failed: %v\n", id, err)
            } else {
                fmt.Printf("✅ Request %d succeeded\n", id)
            }

            time.Sleep(200 * time.Millisecond)
        }(i)
    }

    wg.Wait()
    close(stopMonitor)
    <-monitorDone

    // Final health report
    fmt.Println("\n📋 Final Health Report:")
    health := cb.HealthCheck()
    for key, value := range health {
        fmt.Printf("  %s: %v\n", key, value)
    }
}

HTTP Client with Circuit Breaker

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "time"
)

// HTTPClient is an http.Client whose requests are routed through an
// AdvancedCircuitBreaker.
type HTTPClient struct {
    circuitBreaker *AdvancedCircuitBreaker // guards every request issued below
    client         *http.Client            // performs the actual HTTP round trips
}

// NewHTTPClient returns an HTTPClient whose underlying http.Client uses the
// given overall timeout and whose breaker is built from cbConfig.
func NewHTTPClient(timeout time.Duration, cbConfig AdvancedConfig) *HTTPClient {
    httpClient := &http.Client{Timeout: timeout}
    breaker := NewAdvancedCircuitBreaker(cbConfig)
    return &HTTPClient{client: httpClient, circuitBreaker: breaker}
}

// Get performs a GET request with circuit breaker protection.
//
// A 5xx response counts as a breaker failure: its body is drained and closed
// and (nil, error) is returned. On success the caller owns the returned
// response and must close its body.
func (hc *HTTPClient) Get(ctx context.Context, url string) (*http.Response, error) {
    // ExecuteWithContext runs the request on its own goroutine and may stop
    // waiting for it (timeout or cancellation) while it is still in flight,
    // so resp and abandoned are shared between that goroutine and Get.
    // lock is a one-slot channel used as a mutex to guard them: a send
    // acquires it, a receive releases it.
    var (
        lock      = make(chan struct{}, 1)
        resp      *http.Response
        abandoned bool
    )

    err := hc.circuitBreaker.ExecuteWithContext(ctx, func(ctx context.Context) error {
        r, err := hc.doGet(ctx, url)
        if err != nil {
            return err
        }

        lock <- struct{}{}
        defer func() { <-lock }()
        if abandoned {
            // Get already returned an error and will never hand this
            // response to a caller, so release its connection here. The
            // breaker already recorded this call when it stopped waiting,
            // so the return value is ignored.
            r.Body.Close()
            return nil
        }
        resp = r
        return nil
    })

    lock <- struct{}{}
    defer func() { <-lock }()

    if err != nil {
        // The request can succeed just as the breaker gives up on it; close
        // that response instead of leaking it.
        if resp != nil {
            resp.Body.Close()
        }
        abandoned = true
        return nil, err
    }
    return resp, nil
}

// doGet issues a single GET request. A 5xx status is reported as an error,
// with the body drained and closed so the connection can be reused.
func (hc *HTTPClient) doGet(ctx context.Context, url string) (*http.Response, error) {
    req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
    if err != nil {
        return nil, err
    }

    resp, err := hc.client.Do(req)
    if err != nil {
        return nil, err
    }

    // Consider 5xx status codes as failures
    if resp.StatusCode >= 500 {
        // Best-effort drain: a read error changes nothing, the body is
        // discarded either way.
        _, _ = io.Copy(io.Discard, resp.Body)
        resp.Body.Close()
        return nil, fmt.Errorf("server error: %d", resp.StatusCode)
    }

    return resp, nil
}

// GetJSON performs a GET request through the circuit breaker and unmarshals
// the JSON response body into target. A non-2xx response is returned as an
// error instead of being decoded, because its body is typically an error
// payload rather than the expected document.
func (hc *HTTPClient) GetJSON(ctx context.Context, url string, target interface{}) error {
    resp, err := hc.Get(ctx, url)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    if resp.StatusCode < 200 || resp.StatusCode > 299 {
        return fmt.Errorf("unexpected status: %s", resp.Status)
    }

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return err
    }

    return json.Unmarshal(body, target)
}

// GetHealth exposes the HealthCheck report of the client's circuit breaker.
func (hc *HTTPClient) GetHealth() map[string]interface{} {
    breaker := hc.circuitBreaker
    return breaker.HealthCheck()
}

// Example usage
func main() {
    config := AdvancedConfig{
        MaxFailures:      3,
        ResetTimeout:     5 * time.Second,
        HalfOpenMaxCalls: 2,
        CallTimeout:      2 * time.Second,
        OnStateChange: func(from, to State) {
            fmt.Printf("🔄 HTTP Client circuit breaker: %s -> %s\n", from, to)
        },
    }

    client := NewHTTPClient(3*time.Second, config)

    // Test URLs (some will fail)
    urls := []string{
        "https://httpbin.org/status/200", // Success
        "https://httpbin.org/status/500", // Server error
        "https://httpbin.org/delay/1",    // Success with delay
        "https://httpbin.org/status/503", // Server error
        "https://httpbin.org/status/500", // Server error
        "https://httpbin.org/status/502", // Server error (should open circuit)
        "https://httpbin.org/status/200", // Should be rejected
        "https://httpbin.org/status/200", // Should be rejected
    }

    for i, url := range urls {
        fmt.Printf("\n--- Request %d: %s ---\n", i+1, url)

        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

        resp, err := client.Get(ctx, url)
        if err != nil {
            fmt.Printf("❌ Error: %v\n", err)
        } else {
            // resp.Status already includes the numeric code (e.g. "200 OK").
            fmt.Printf("✅ Success: %s\n", resp.Status)
            resp.Body.Close()
        }

        cancel()

        // Show current health
        health := client.GetHealth()
        fmt.Printf("State: %s, Success Rate: %s\n",
            health["state"], health["success_rate"])

        time.Sleep(1 * time.Second)
    }

    // Wait for circuit to potentially reset
    fmt.Println("\n--- Waiting for potential reset ---")
    time.Sleep(6 * time.Second)

    // Try again
    fmt.Println("\n--- Testing recovery ---")
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    resp, err := client.Get(ctx, "https://httpbin.org/status/200")
    if err != nil {
        fmt.Printf("❌ Recovery test failed: %v\n", err)
    } else {
        fmt.Printf("✅ Recovery test succeeded: %s\n", resp.Status)
        resp.Body.Close()
    }

    // Final health report
    fmt.Println("\n📋 Final Health Report:")
    health := client.GetHealth()
    for key, value := range health {
        fmt.Printf("  %s: %v\n", key, value)
    }
}

Best Practices

  1. Choose Appropriate Thresholds: Set failure thresholds based on service characteristics
  2. Monitor State Changes: Log state transitions for debugging
  3. Implement Fallbacks: Provide alternative responses when circuit is open
  4. Use Timeouts: Combine with timeouts to handle slow responses
  5. Gradual Recovery: Use half-open state to test service recovery
  6. Metrics Collection: Track success rates, response times, and state changes
  7. Configuration: Make thresholds configurable for different environments

Common Pitfalls

  1. Too Sensitive: Setting thresholds too low causes unnecessary trips
  2. Too Tolerant: High thresholds don’t protect against cascading failures
  3. No Fallbacks: Not providing alternative responses when circuit is open
  4. Ignoring Context: Not respecting context cancellation in protected functions
  5. Poor Monitoring: Not tracking circuit breaker metrics and health

The Circuit Breaker pattern is essential for building resilient distributed systems. It prevents cascading failures, provides fast failure responses, and allows services time to recover, making your applications more robust and reliable.


Next: Learn about the Rate Limiter pattern for controlling the rate of operations and preventing system overload.