Go Concurrency Patterns Series: ← Once Pattern | Series Overview | Circuit Breaker →


What is the Context Pattern?

The Context pattern uses Go’s context package to carry cancellation signals, deadlines, timeouts, and request-scoped values across API boundaries and between goroutines. It’s essential for building responsive, cancellable operations and managing request lifecycles.

Key Features:

  • Cancellation: Signal when operations should stop
  • Timeouts: Automatically cancel after a duration
  • Deadlines: Cancel at a specific time
  • Values: Carry request-scoped data

Real-World Use Cases

  • HTTP Servers: Request cancellation and timeouts
  • Database Operations: Query timeouts and cancellation
  • API Calls: External service timeouts
  • Background Jobs: Graceful shutdown
  • Microservices: Request tracing and correlation IDs
  • File Operations: Long-running I/O with cancellation

Basic Context Usage

package main

import (
    "context"
    "fmt"
    "math/rand"
    "time"
)

// simulateWork simulates a long-running operation that needs duration to
// finish unless ctx is cancelled first.
//
// It returns nil once the full duration has elapsed, and ctx.Err() when the
// context is cancelled or its deadline passes before then. A context that is
// already done on entry always yields ctx.Err(), even for a zero or negative
// duration.
func simulateWork(ctx context.Context, name string, duration time.Duration) error {
    fmt.Printf("%s: Starting work (expected duration: %v)\n", name, duration)

    // Check cancellation up front: if both select cases below were ready at
    // the same time, select would pick one of them at random.
    if err := ctx.Err(); err != nil {
        fmt.Printf("%s: Work cancelled: %v\n", name, err)
        return err
    }

    // An explicit timer (instead of time.After) can be stopped on return, so
    // a cancelled call does not leave a pending timer behind.
    timer := time.NewTimer(duration)
    defer timer.Stop()

    select {
    case <-timer.C:
        fmt.Printf("%s: Work completed successfully\n", name)
        return nil
    case <-ctx.Done():
        fmt.Printf("%s: Work cancelled: %v\n", name, ctx.Err())
        return ctx.Err()
    }
}

// main runs three independent demos of context cancellation: a timeout shared
// by two tasks, a manual cancel issued from another goroutine, and an
// absolute deadline.
func main() {
    // Example 1: Context with timeout
    fmt.Println("=== Context with Timeout ===")
    ctx1, cancel1 := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel1()

    err := simulateWork(ctx1, "Task1", 1*time.Second) // Should complete
    if err != nil {
        fmt.Printf("Task1 error: %v\n", err)
    }

    // Task2 reuses ctx1, so the time Task1 spent already counts against the
    // same 2-second budget.
    err = simulateWork(ctx1, "Task2", 3*time.Second) // Should timeout
    if err != nil {
        fmt.Printf("Task2 error: %v\n", err)
    }

    // Example 2: Manual cancellation
    fmt.Println("\n=== Manual Cancellation ===")
    ctx2, cancel2 := context.WithCancel(context.Background())
    // Defer the cancel right away so the context is released on every path;
    // calling a CancelFunc more than once is harmless.
    defer cancel2()

    go func() {
        time.Sleep(1 * time.Second)
        fmt.Println("Cancelling context...")
        cancel2()
    }()

    err = simulateWork(ctx2, "Task3", 3*time.Second) // Should be cancelled
    if err != nil {
        fmt.Printf("Task3 error: %v\n", err)
    }

    // Example 3: Context with deadline
    fmt.Println("\n=== Context with Deadline ===")
    deadline := time.Now().Add(1500 * time.Millisecond)
    ctx3, cancel3 := context.WithDeadline(context.Background(), deadline)
    defer cancel3()

    err = simulateWork(ctx3, "Task4", 2*time.Second) // Should hit deadline
    if err != nil {
        fmt.Printf("Task4 error: %v\n", err)
    }
}

Context with Values

package main

import (
    "context"
    "fmt"
    "log"
    "net/http"
    "time"
)

// Key types for context values
//
// contextKey is a dedicated named type for the context keys declared in this
// file. Because the key type is not a plain string, these keys cannot collide
// with string keys or with keys of other types stored in the same context.
type contextKey string

// Context keys under which the request-scoped identifiers are stored.
const (
    // RequestIDKey is the key for the request ID string.
    RequestIDKey contextKey = "requestID"
    // UserIDKey is the key for the user ID string.
    UserIDKey    contextKey = "userID"
    // TraceIDKey is the key for the trace ID string.
    TraceIDKey   contextKey = "traceID"
)

// RequestInfo holds request-scoped information
//
// withRequestInfo copies RequestID, UserID and TraceID into a context;
// StartTime is carried on the struct only and is not stored in the context.
type RequestInfo struct {
    // RequestID identifies the individual request.
    RequestID string
    // UserID identifies the user on whose behalf the request runs.
    UserID    string
    // TraceID identifies the trace the request belongs to.
    TraceID   string
    // StartTime is set by the caller when the RequestInfo is created.
    StartTime time.Time
}

// withRequestInfo returns a context derived from ctx that carries the request
// ID, user ID and trace ID from info under their respective keys.
func withRequestInfo(ctx context.Context, info RequestInfo) context.Context {
    withRequest := context.WithValue(ctx, RequestIDKey, info.RequestID)
    withUser := context.WithValue(withRequest, UserIDKey, info.UserID)
    return context.WithValue(withUser, TraceIDKey, info.TraceID)
}

// getRequestID returns the request ID stored in ctx, or "unknown" when no
// string value is stored under RequestIDKey.
func getRequestID(ctx context.Context) string {
    requestID, isString := ctx.Value(RequestIDKey).(string)
    if !isString {
        return "unknown"
    }
    return requestID
}

// getUserID returns the user ID stored in ctx, or "anonymous" when no string
// value is stored under UserIDKey.
func getUserID(ctx context.Context) string {
    switch userID := ctx.Value(UserIDKey).(type) {
    case string:
        return userID
    default:
        return "anonymous"
    }
}

// getTraceID returns the trace ID stored in ctx, or "no-trace" when no string
// value is stored under TraceIDKey.
func getTraceID(ctx context.Context) string {
    stored := ctx.Value(TraceIDKey)
    traceID, found := stored.(string)
    if found {
        return traceID
    }
    return "no-trace"
}

// logWithContext prints message to standard output, prefixed with the request
// ID, user ID and trace ID found in ctx (or their fallback values).
func logWithContext(ctx context.Context, message string) {
    fmt.Printf("[%s][%s][%s] %s\n", getRequestID(ctx), getUserID(ctx), getTraceID(ctx), message)
}

// businessLogic stands in for 500ms of business logic. It returns nil when
// the work finishes and ctx.Err() when ctx is done first.
func businessLogic(ctx context.Context) error {
    logWithContext(ctx, "Starting business logic")

    // Simulated work: whichever of "work finished" and "context done"
    // happens first decides the outcome.
    workDone := time.After(500 * time.Millisecond)
    select {
    case <-ctx.Done():
        logWithContext(ctx, "Business logic cancelled")
        return ctx.Err()
    case <-workDone:
    }

    logWithContext(ctx, "Business logic completed")
    return nil
}

// databaseOperation stands in for a 200ms database query. It returns nil when
// the query finishes and ctx.Err() when ctx is done first.
func databaseOperation(ctx context.Context, query string) error {
    announcement := fmt.Sprintf("Executing query: %s", query)
    logWithContext(ctx, announcement)

    queryDone := time.After(200 * time.Millisecond)
    select {
    case <-ctx.Done():
        logWithContext(ctx, "Database operation cancelled")
        return ctx.Err()
    case <-queryDone:
    }

    logWithContext(ctx, "Database operation completed")
    return nil
}

// externalAPICall stands in for a 300ms call to an external API. It returns
// nil when the call finishes and ctx.Err() when ctx is done first.
func externalAPICall(ctx context.Context, endpoint string) error {
    announcement := fmt.Sprintf("Calling external API: %s", endpoint)
    logWithContext(ctx, announcement)

    callDone := time.After(300 * time.Millisecond)
    select {
    case <-ctx.Done():
        logWithContext(ctx, "External API call cancelled")
        return ctx.Err()
    case <-callDone:
    }

    logWithContext(ctx, "External API call completed")
    return nil
}

// handleRequest simulates handling one request: it runs the database
// operation, the external API call and the business logic in that order, all
// under the same ctx, and returns the first error encountered.
func handleRequest(ctx context.Context) error {
    logWithContext(ctx, "Handling request")

    // The stages run strictly in sequence; the first failure ends the request.
    stages := []func(context.Context) error{
        func(c context.Context) error { return databaseOperation(c, "SELECT * FROM users") },
        func(c context.Context) error { return externalAPICall(c, "/api/v1/data") },
        businessLogic,
    }
    for _, stage := range stages {
        if err := stage(ctx); err != nil {
            return err
        }
    }

    logWithContext(ctx, "Request handled successfully")
    return nil
}

// main handles two simulated requests: one that completes within its timeout
// and one that is cancelled from another goroutine partway through.
func main() {
    // Simulate incoming request
    requestInfo := RequestInfo{
        RequestID: "req-12345",
        UserID:    "user-67890",
        TraceID:   "trace-abcdef",
        StartTime: time.Now(),
    }

    // Create context with timeout and request info
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    ctx = withRequestInfo(ctx, requestInfo)

    // Handle the request
    if err := handleRequest(ctx); err != nil {
        logWithContext(ctx, fmt.Sprintf("Request failed: %v", err))
    }

    // Example with early cancellation
    fmt.Println("\n=== Early Cancellation Example ===")
    ctx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second)
    // Defer the cancel right away so the timeout's resources are released on
    // every path; calling a CancelFunc more than once is harmless.
    defer cancel2()

    requestInfo2 := RequestInfo{
        RequestID: "req-54321",
        UserID:    "user-09876",
        TraceID:   "trace-fedcba",
        StartTime: time.Now(),
    }

    ctx2 = withRequestInfo(ctx2, requestInfo2)

    // Cancel after 800ms
    go func() {
        time.Sleep(800 * time.Millisecond)
        logWithContext(ctx2, "Cancelling request early")
        cancel2()
    }()

    if err := handleRequest(ctx2); err != nil {
        logWithContext(ctx2, fmt.Sprintf("Request failed: %v", err))
    }
}

HTTP Server with Context

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "math/rand"
    "net/http"
    "strconv"
    "time"
)

// Response represents an API response
//
// It is the JSON body written by every handler in this file.
type Response struct {
    // Message is a human-readable outcome description.
    Message   string        `json:"message"`
    // RequestID echoes the ID the handler read from the request context.
    RequestID string        `json:"request_id"`
    // Duration is how long the handler worked. time.Duration is an integer
    // type, so encoding/json writes it as a number of nanoseconds.
    Duration  time.Duration `json:"duration"`
    // Data carries the handler-specific payload, if any.
    Data      interface{}   `json:"data,omitempty"`
}

// middleware wraps next so that every request gets a generated request ID
// and a timeout on its context.
//
// The timeout defaults to 5 seconds and can be overridden with the "timeout"
// query parameter (any time.ParseDuration string, e.g. "?timeout=2s"). The
// parameter comes from the client, so it is validated: values that do not
// parse, or that are zero or negative, fall back to the default, and values
// above maxTimeout are clamped to maxTimeout.
//
// The request ID is stored in the request context under RequestIDKey and is
// also returned in the X-Request-ID response header.
func middleware(next http.HandlerFunc) http.HandlerFunc {
    const (
        defaultTimeout = 5 * time.Second
        // maxTimeout keeps a client from pinning a handler for an arbitrarily
        // long time via the query string.
        maxTimeout = 60 * time.Second
    )

    return func(w http.ResponseWriter, r *http.Request) {
        // Generate request ID
        requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())

        // Get timeout from query parameter (default 5 seconds)
        timeout := defaultTimeout
        if timeoutStr := r.URL.Query().Get("timeout"); timeoutStr != "" {
            // A non-positive timeout would yield a context that is already
            // expired, so it is treated like an invalid value.
            if t, err := time.ParseDuration(timeoutStr); err == nil && t > 0 {
                timeout = t
                if timeout > maxTimeout {
                    timeout = maxTimeout
                }
            }
        }

        // Create context with timeout
        ctx, cancel := context.WithTimeout(r.Context(), timeout)
        defer cancel()

        // Add request ID to context
        ctx = context.WithValue(ctx, RequestIDKey, requestID)

        // Create new request with updated context
        r = r.WithContext(ctx)

        // Add request ID to response headers
        w.Header().Set("X-Request-ID", requestID)

        next(w, r)
    }
}

// simulateSlowOperation simulates an operation that takes duration to finish
// and respects context cancellation.
//
// It returns a completion message and a nil error once duration has elapsed,
// or an empty string and ctx.Err() when ctx is done first. A context that is
// already done on entry always yields ctx.Err(), even for a zero or negative
// duration.
func simulateSlowOperation(ctx context.Context, duration time.Duration) (string, error) {
    // Check cancellation up front: if both select cases below were ready at
    // the same time, select would pick one of them at random.
    if err := ctx.Err(); err != nil {
        return "", err
    }

    // An explicit timer (instead of time.After) can be stopped on return, so
    // a request that is cancelled early does not leave a pending timer behind.
    timer := time.NewTimer(duration)
    defer timer.Stop()

    select {
    case <-timer.C:
        return fmt.Sprintf("Operation completed after %v", duration), nil
    case <-ctx.Done():
        return "", ctx.Err()
    }
}

// fastHandler handles requests quickly: it runs a 100ms simulated operation
// under the request context and reports the outcome as a JSON Response.
//
// On success it replies 200 with the operation result in Data. If the request
// context is done first it replies 408 (http.StatusRequestTimeout) with the
// message "Request failed".
func fastHandler(w http.ResponseWriter, r *http.Request) {
    start := time.Now()
    ctx := r.Context()
    requestID := getRequestID(ctx)

    result, err := simulateSlowOperation(ctx, 100*time.Millisecond)
    duration := time.Since(start)

    response := Response{
        RequestID: requestID,
        Duration:  duration,
    }

    // Headers must be set before WriteHeader/Write; without an explicit
    // Content-Type the body would be content-sniffed instead of declared JSON.
    w.Header().Set("Content-Type", "application/json")

    if err != nil {
        response.Message = "Request failed"
        w.WriteHeader(http.StatusRequestTimeout)
    } else {
        response.Message = "Success"
        response.Data = result
    }

    // The status line is already sent at this point, so an encoding or write
    // failure can only be logged.
    if encodeErr := json.NewEncoder(w).Encode(response); encodeErr != nil {
        log.Printf("[%s] writing response: %v", requestID, encodeErr)
    }
}

// slowHandler handles requests that might timeout: it runs a simulated
// operation lasting a random 1-10 seconds under the request context and
// reports the outcome as a JSON Response.
//
// On success it replies 200 with the operation result in Data. If the request
// context is done first (timeout or client disconnect) it replies 408
// (http.StatusRequestTimeout).
func slowHandler(w http.ResponseWriter, r *http.Request) {
    start := time.Now()
    ctx := r.Context()
    requestID := getRequestID(ctx)

    // Random duration between 1-10 seconds
    duration := time.Duration(1+rand.Intn(10)) * time.Second

    result, err := simulateSlowOperation(ctx, duration)
    elapsed := time.Since(start)

    response := Response{
        RequestID: requestID,
        Duration:  elapsed,
    }

    // Headers must be set before WriteHeader/Write; without an explicit
    // Content-Type the body would be content-sniffed instead of declared JSON.
    w.Header().Set("Content-Type", "application/json")

    if err != nil {
        response.Message = "Request timed out or cancelled"
        w.WriteHeader(http.StatusRequestTimeout)
    } else {
        response.Message = "Success"
        response.Data = result
    }

    // The status line is already sent at this point, so an encoding or write
    // failure can only be logged.
    if encodeErr := json.NewEncoder(w).Encode(response); encodeErr != nil {
        log.Printf("[%s] writing response: %v", requestID, encodeErr)
    }
}

// batchHandler processes multiple operations
func batchHandler(w http.ResponseWriter, r *http.Request) {
    start := time.Now()
    ctx := r.Context()
    requestID := getRequestID(ctx)
    
    // Get batch size from query parameter
    batchSizeStr := r.URL.Query().Get("size")
    batchSize := 3
    if batchSizeStr != "" {
        if size, err := strconv.Atoi(batchSizeStr); err == nil && size > 0 {
            batchSize = size
        }
    }
    
    results := make([]string, 0, batchSize)
    
    // Process operations sequentially, checking context each time
    for i := 0; i < batchSize; i++ {
        select {
        case <-ctx.Done():
            // Context cancelled, return partial results
            response := Response{
                RequestID: requestID,
                Duration:  time.Since(start),
                Message:   fmt.Sprintf("Batch cancelled after %d/%d operations", i, batchSize),
                Data:      results,
            }
            w.WriteHeader(http.StatusRequestTimeout)
            json.NewEncoder(w).Encode(response)
            return
        default:
        }
        
        result, err := simulateSlowOperation(ctx, 200*time.Millisecond)
        if err != nil {
            response := Response{
                RequestID: requestID,
                Duration:  time.Since(start),
                Message:   fmt.Sprintf("Batch failed at operation %d: %v", i+1, err),
                Data:      results,
            }
            w.WriteHeader(http.StatusRequestTimeout)
            json.NewEncoder(w).Encode(response)
            return
        }
        
        results = append(results, fmt.Sprintf("Op%d: %s", i+1, result))
    }
    
    response := Response{
        RequestID: requestID,
        Duration:  time.Since(start),
        Message:   "Batch completed successfully",
        Data:      results,
    }
    
    json.NewEncoder(w).Encode(response)
}

// main registers the three demo endpoints, each wrapped in middleware, on the
// default mux and serves them on :8080 until the server fails.
func main() {
    http.HandleFunc("/fast", middleware(fastHandler))
    http.HandleFunc("/slow", middleware(slowHandler))
    http.HandleFunc("/batch", middleware(batchHandler))

    fmt.Println("Server starting on :8080")
    fmt.Println("Endpoints:")
    fmt.Println("  GET /fast - Fast operation (100ms)")
    fmt.Println("  GET /slow - Slow operation (1-10s random)")
    fmt.Println("  GET /batch?size=N - Batch operations")
    fmt.Println("  Add ?timeout=5s to set custom timeout")

    // An explicit http.Server allows connection-level timeouts, which
    // http.ListenAndServe leaves unset. No WriteTimeout is configured on
    // purpose: handler run time is already bounded per request by the
    // context timeout that middleware installs.
    server := &http.Server{
        Addr:              ":8080",
        ReadHeaderTimeout: 5 * time.Second,  // drop clients that trickle their headers
        IdleTimeout:       60 * time.Second, // reclaim idle keep-alive connections
    }

    log.Fatal(server.ListenAndServe())
}

Context Propagation in Goroutines

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// Worker represents a worker that processes tasks
type Worker struct {
    // ID is the worker's number, printed in its progress messages.
    ID   int
    // Name is the worker's display name, printed in its progress messages.
    Name string
}

// ProcessTask runs one task as three 200ms steps, printing progress tagged
// with the worker and the request ID found in ctx. It returns nil after the
// last step, or ctx.Err() as soon as ctx is done while a step is pending.
func (w *Worker) ProcessTask(ctx context.Context, taskID int) error {
    const (
        totalSteps   = 3
        stepDuration = 200 * time.Millisecond
    )

    reqID := getRequestID(ctx)

    fmt.Printf("Worker %d (%s) [%s]: Starting task %d\n", 
        w.ID, w.Name, reqID, taskID)

    // Each step races its own timer against cancellation of ctx.
    for step := 1; step <= totalSteps; step++ {
        stepDone := time.After(stepDuration)

        select {
        case <-ctx.Done():
            cause := ctx.Err()
            fmt.Printf("Worker %d (%s) [%s]: Task %d cancelled at step %d: %v\n", 
                w.ID, w.Name, reqID, taskID, step, cause)
            return cause
        case <-stepDone:
            fmt.Printf("Worker %d (%s) [%s]: Task %d step %d completed\n", 
                w.ID, w.Name, reqID, taskID, step)
        }
    }

    fmt.Printf("Worker %d (%s) [%s]: Task %d completed successfully\n", 
        w.ID, w.Name, reqID, taskID)
    return nil
}

// TaskManager manages task distribution
type TaskManager struct {
    // workers is the fixed set of workers that tasks are handed to; one
    // goroutine is started per entry when tasks are processed.
    workers []Worker
}

// NewTaskManager returns a TaskManager staffed with three workers, Alice, Bob
// and Charlie, numbered 1 through 3.
func NewTaskManager() *TaskManager {
    names := []string{"Alice", "Bob", "Charlie"}

    roster := make([]Worker, 0, len(names))
    for i, name := range names {
        roster = append(roster, Worker{ID: i + 1, Name: name})
    }

    return &TaskManager{workers: roster}
}

// ProcessTasksConcurrently processes tasks 1..taskCount using one goroutine
// per worker, all sharing ctx.
//
// It always waits for every worker goroutine to exit before returning, so no
// goroutine started here outlives the call. Workers stop picking up tasks as
// soon as ctx is done, and a worker whose task fails stops after reporting
// that error.
//
// It returns:
//   - nil when every task completed (or taskCount is not positive),
//   - the first error reported by a worker, if any,
//   - ctx.Err() when cancellation left tasks unprocessed,
//   - an error when there are tasks but no workers to run them.
func (tm *TaskManager) ProcessTasksConcurrently(ctx context.Context, taskCount int) error {
    // Nothing to do; this also guards make(chan) against a negative size,
    // which would panic.
    if taskCount <= 0 {
        return nil
    }
    if len(tm.workers) == 0 {
        return fmt.Errorf("no workers available to process %d tasks", taskCount)
    }

    // Queue every task up front. The buffer holds all of them, so filling it
    // never blocks, and closing it lets workers detect "no more tasks".
    taskChan := make(chan int, taskCount)
    for i := 1; i <= taskCount; i++ {
        taskChan <- i
    }
    close(taskChan)

    // One slot per worker: each worker reports at most one error and then
    // exits, so sends on errorChan never block.
    errorChan := make(chan error, len(tm.workers))

    // Start workers
    var wg sync.WaitGroup
    for _, worker := range tm.workers {
        wg.Add(1)
        go func(w Worker) {
            defer wg.Done()

            for {
                // Stop promptly after cancellation even if tasks remain.
                if ctx.Err() != nil {
                    return
                }

                // taskChan is pre-filled and closed, so this never blocks.
                taskID, ok := <-taskChan
                if !ok {
                    return // No more tasks
                }

                if err := w.ProcessTask(ctx, taskID); err != nil {
                    errorChan <- err
                    return
                }
            }
        }(worker)
    }

    // Workers honour ctx, so this wait ends shortly after any cancellation.
    wg.Wait()
    close(errorChan)

    // Report the first worker error, if there is one.
    if err, ok := <-errorChan; ok {
        return err
    }

    // Without a worker error, tasks are only left behind when the workers
    // stopped because ctx was done.
    if len(taskChan) > 0 {
        return ctx.Err()
    }
    return nil
}

// main runs three batches through one TaskManager: one that finishes within
// its timeout, one that times out, and one that is cancelled manually.
func main() {
    manager := NewTaskManager()

    // Example 1: Normal completion
    fmt.Println("=== Normal Completion ===")
    ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel1()
    ctx1 = context.WithValue(ctx1, RequestIDKey, "batch-001")

    err := manager.ProcessTasksConcurrently(ctx1, 6)
    if err != nil {
        fmt.Printf("Batch processing failed: %v\n", err)
    } else {
        fmt.Println("Batch processing completed successfully")
    }

    // Pause between examples so their output does not interleave.
    time.Sleep(1 * time.Second)

    // Example 2: Timeout scenario
    fmt.Println("\n=== Timeout Scenario ===")
    ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel2()
    ctx2 = context.WithValue(ctx2, RequestIDKey, "batch-002")

    err = manager.ProcessTasksConcurrently(ctx2, 10)
    if err != nil {
        fmt.Printf("Batch processing failed: %v\n", err)
    } else {
        fmt.Println("Batch processing completed successfully")
    }

    time.Sleep(1 * time.Second)

    // Example 3: Manual cancellation
    fmt.Println("\n=== Manual Cancellation ===")
    ctx3, cancel3 := context.WithCancel(context.Background())
    // Defer the cancel right away so the context is released on every path;
    // calling a CancelFunc more than once is harmless.
    defer cancel3()
    ctx3 = context.WithValue(ctx3, RequestIDKey, "batch-003")

    // Cancel after 800ms
    go func() {
        time.Sleep(800 * time.Millisecond)
        fmt.Println("Manually cancelling batch...")
        cancel3()
    }()

    err = manager.ProcessTasksConcurrently(ctx3, 8)
    if err != nil {
        fmt.Printf("Batch processing failed: %v\n", err)
    } else {
        fmt.Println("Batch processing completed successfully")
    }
}

Best Practices

  1. Always Accept Context: Functions that may block or perform I/O should accept a context.Context as their first parameter, conventionally named ctx
  2. Don’t Store Context: Pass context as parameter, don’t store in structs
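  3. Use context.TODO(): Pass context.TODO() as a placeholder when it is not yet clear which context to use or the caller does not provide one; never pass a nil context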
  4. Derive Contexts: Create child contexts from parent contexts
  5. Handle Cancellation: Always check ctx.Done() in long-running operations
  6. Limit Context Values: Use sparingly and for request-scoped data only
  7. Use Typed Keys: Define custom types for context keys to avoid collisions

Common Pitfalls

1. Ignoring Context Cancellation

// ❌ Bad: Ignoring context cancellation
//
// badOperation accepts ctx but never consults it: the loop always runs all
// 1000 iterations (10ms of sleep each) and returns nil, no matter whether the
// caller cancelled ctx in the meantime.
func badOperation(ctx context.Context) error {
    for i := 0; i < 1000; i++ {
        // Long operation without checking context
        time.Sleep(10 * time.Millisecond)
        // Process item i
    }
    return nil
}

// ✅ Good: Checking context regularly
//
// goodOperation processes 1000 items of roughly 10ms each and polls ctx
// before every item. It returns ctx.Err() at the first poll that finds ctx
// done, and nil after the last item otherwise.
func goodOperation(ctx context.Context) error {
    const (
        itemCount   = 1000
        perItemWork = 10 * time.Millisecond
    )

    for item := 0; item < itemCount; item++ {
        // Non-blocking poll: the default case keeps the loop going while ctx
        // is still live.
        select {
        case <-ctx.Done():
            return ctx.Err()
        default:
        }

        time.Sleep(perItemWork)
        // Process the current item
    }

    return nil
}

2. Using Context for Optional Parameters

// ❌ Bad: Using context for optional parameters
//
// badFunction looks up its timeout in ctx under the plain string key
// "timeout". The option is invisible in the signature, and nothing checks
// that it is present or has the right type — exactly the pattern to avoid.
func badFunction(ctx context.Context) {
    if timeout, ok := ctx.Value("timeout").(time.Duration); ok {
        // Use timeout
        _ = timeout // placeholder use: Go rejects declared-but-unused variables
    }
}

// ✅ Good: Use function parameters for optional values
//
// goodFunction takes the timeout as an explicit, typed parameter, so the
// option is visible in the signature and checked by the compiler. The body is
// left as a placeholder in this example.
func goodFunction(ctx context.Context, timeout time.Duration) {
    // Use timeout parameter
}

The Context pattern is fundamental for building robust, cancellable operations in Go. It enables graceful handling of timeouts, cancellations, and request-scoped data, making your applications more responsive and resource-efficient.


Next: Learn about Circuit Breaker Pattern for fault tolerance and resilience in distributed systems.