Go Concurrency Patterns Series: ← Circuit Breaker | Series Overview | Memory Model →


What is Context Propagation?

Context propagation is the practice of threading context.Context through your application to carry cancellation signals, deadlines, and request-scoped values across API boundaries, between goroutines, and between services. This is critical for building observable, responsive distributed systems.

Key Capabilities:

  • Distributed Tracing: Propagate trace IDs across services
  • Cancellation Cascades: Cancel entire request trees
  • Deadline Enforcement: Ensure requests complete within time budgets
  • Request-Scoped Values: Carry metadata without polluting function signatures

Real-World Use Cases

  • Microservices: Trace requests across multiple services
  • API Gateways: Propagate timeouts and user context
  • Database Layers: Cancel queries when requests are abandoned
  • Message Queues: Propagate processing deadlines
  • HTTP Middleware: Extract and inject trace headers
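  • gRPC Services: Deadlines and cancellation flow automatically to outgoing calls made with the incoming context; metadata must be forwarded explicitly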

Basic Context Propagation

Propagating Through Function Calls

package main

import (
	"context"
	"fmt"
	"time"
)

// contextKey is an unexported key type for request-scoped values. Keys of a
// dedicated type cannot collide with keys defined by other packages, unlike
// bare strings (staticcheck SA1029 flags built-in types used as keys).
type contextKey string

const (
	userIDKey    contextKey = "user_id"
	requestIDKey contextKey = "request_id"
)

// ServiceA calls ServiceB which calls ServiceC; the context propagates
// through all layers.
//
// ServiceA is the entry point: it stores the user ID and a freshly generated
// request ID in ctx and hands the enriched context to ServiceB.
func ServiceA(ctx context.Context, userID string) error {
	// Add request-scoped values under typed keys.
	ctx = context.WithValue(ctx, userIDKey, userID)
	ctx = context.WithValue(ctx, requestIDKey, generateRequestID())

	fmt.Printf("[ServiceA] Processing request for user: %s\n", userID)

	// Propagate context to next service.
	return ServiceB(ctx)
}

// ServiceB logs the request-scoped values, then gives the downstream call a
// 2-second budget by deriving a timeout context before calling ServiceC.
func ServiceB(ctx context.Context) error {
	userID := stringValue(ctx, userIDKey)
	requestID := stringValue(ctx, requestIDKey)

	fmt.Printf("[ServiceB] User: %s, Request: %s\n", userID, requestID)

	// The derived context keeps the parent's values and cancellation and adds
	// its own deadline; cancel releases its timer when ServiceB returns.
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	return ServiceC(ctx)
}

// ServiceC simulates one second of work. It returns ctx.Err() early if the
// context is cancelled or its deadline passes first.
func ServiceC(ctx context.Context) error {
	userID := stringValue(ctx, userIDKey)
	requestID := stringValue(ctx, requestIDKey)

	fmt.Printf("[ServiceC] Processing for User: %s, Request: %s\n", userID, requestID)

	// Simulate work
	select {
	case <-time.After(1 * time.Second):
		fmt.Println("[ServiceC] Work completed")
		return nil
	case <-ctx.Done():
		fmt.Printf("[ServiceC] Cancelled: %v\n", ctx.Err())
		return ctx.Err()
	}
}

// stringValue returns the string stored in ctx under key. It returns
// "unknown" when the key is missing or holds a non-string, so a caller that
// never set the value gets a placeholder instead of a panic.
func stringValue(ctx context.Context, key contextKey) string {
	if v, ok := ctx.Value(key).(string); ok {
		return v
	}
	return "unknown"
}

// generateRequestID returns "req-" followed by the current Unix time in
// nanoseconds. The ID comes solely from the wall clock, so two calls that
// observe the same reading produce the same ID.
func generateRequestID() string {
	return fmt.Sprintf("req-%d", time.Now().UnixNano())
}

// main runs the ServiceA -> ServiceB -> ServiceC chain once, starting from an
// empty background context, and prints any error it returns.
func main() {
	if err := ServiceA(context.Background(), "user-123"); err != nil {
		fmt.Printf("Error: %v\n", err)
	}
}

Output (the request ID differs on every run):

[ServiceA] Processing request for user: user-123
[ServiceB] User: user-123, Request: req-1234567890
[ServiceC] Processing for User: user-123, Request: req-1234567890
[ServiceC] Work completed

HTTP Request Context Propagation

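Server-Side: Extracting and Propagating Context

This program calls generateRequestID from the previous example, so keep both examples in one package (or copy the helper over) to build it.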

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// contextKey is the unexported type for this program's context keys. No
// other package can create a value of this type, so these keys can never
// collide with keys set elsewhere.
type contextKey string

// Keys for request-scoped values carried in a request's context.
const (
	// TraceIDKey holds the trace ID (reused from X-Trace-ID or generated).
	TraceIDKey = contextKey("trace_id")
	// RequestIDKey holds the ID of this particular request.
	RequestIDKey = contextKey("request_id")
	// UserIDKey holds a user identifier.
	UserIDKey = contextKey("user_id")
)

// TracingMiddleware makes sure every request carries a trace ID and a request
// ID. Incoming X-Trace-ID / X-Request-ID headers are reused so the IDs survive
// service hops; missing ones are generated (trace ID first). Both IDs are
// echoed as response headers, stored in the request context, and the start of
// the request is logged before next runs.
func TracingMiddleware(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		traceID := r.Header.Get("X-Trace-ID")
		requestID := r.Header.Get("X-Request-ID")
		if traceID == "" {
			traceID = generateTraceID()
		}
		if requestID == "" {
			requestID = generateRequestID()
		}

		tagged := context.WithValue(r.Context(), TraceIDKey, traceID)
		tagged = context.WithValue(tagged, RequestIDKey, requestID)

		// Echo the IDs so the caller can correlate its logs with ours.
		headers := w.Header()
		headers.Set("X-Trace-ID", traceID)
		headers.Set("X-Request-ID", requestID)

		fmt.Printf("[%s] %s %s started\n", traceID, r.Method, r.URL.Path)

		next.ServeHTTP(w, r.WithContext(tagged))
	}
	return http.HandlerFunc(serve)
}

// TimeoutMiddleware returns middleware that bounds each request's context by
// timeout. Handlers see the deadline through r.Context(); the middleware does
// not write a response itself when time runs out, so handlers must watch
// ctx.Done(). The derived context is released as soon as next returns.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			bounded, release := context.WithTimeout(r.Context(), timeout)
			defer release()
			next.ServeHTTP(w, r.WithContext(bounded))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}

// handleUserRequest serves /user. It passes the request context to
// fetchUserData so cancellation and any deadline on that context reach the
// simulated downstream call, then writes the fetched data. Any fetch error is
// logged with the trace ID and answered with HTTP 500.
func handleUserRequest(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Use the comma-ok form: if TracingMiddleware did not wrap this handler
	// (for example in a unit test), a bare assertion would panic.
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	// Simulate calling another service
	userData, err := fetchUserData(ctx, "user-123")
	if err != nil {
		fmt.Printf("[%s] Error fetching user: %v\n", traceID, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	fmt.Printf("[%s] Request completed successfully\n", traceID)
	fmt.Fprintf(w, "User data: %s\n", userData)
}

// fetchUserData simulates a 500ms database lookup for userID. It returns
// ctx.Err() if the context is cancelled or its deadline passes first.
func fetchUserData(ctx context.Context, userID string) (string, error) {
	// Comma-ok so a context without a trace ID gets a placeholder instead of
	// a panic.
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	fmt.Printf("[%s] Fetching user data for: %s\n", traceID, userID)

	// Simulate database call
	select {
	case <-time.After(500 * time.Millisecond):
		return fmt.Sprintf("Data for %s", userID), nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

// generateTraceID returns "trace-" followed by the current Unix time in
// nanoseconds. Uniqueness depends on callers seeing distinct clock readings.
func generateTraceID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprintf("trace-%d", nanos)
}

// main wires /user behind TracingMiddleware and a 5-second TimeoutMiddleware
// and serves on :8080. Unlike bare http.ListenAndServe, the http.Server sets
// a header-read timeout so slow clients cannot hold connections open
// indefinitely, and the error that ends serving is reported, not dropped.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/user", handleUserRequest)

	// TracingMiddleware runs first, so the trace ID is already in the context
	// that TimeoutMiddleware derives its deadline-bound context from.
	handler := TracingMiddleware(
		TimeoutMiddleware(5 * time.Second)(mux),
	)

	srv := &http.Server{
		Addr:              ":8080",
		Handler:           handler,
		ReadHeaderTimeout: 5 * time.Second,
	}

	fmt.Println("Server starting on :8080")
	if err := srv.ListenAndServe(); err != nil {
		fmt.Printf("Server error: %v\n", err)
	}
}

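Client-Side: Propagating Context in HTTP Clients

This client reuses TraceIDKey, RequestIDKey, generateTraceID, and generateRequestID from the examples above, so it is meant to live in the same package as them.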

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

// HTTPClient wraps an *http.Client and copies trace identifiers stored in a
// context.Context onto outgoing request headers (see DoWithContext).
type HTTPClient struct {
	client *http.Client // shared by all requests made through this wrapper
}

// NewHTTPClient returns an HTTPClient whose underlying http.Client enforces a
// 10-second overall limit per request, on top of any context deadline.
func NewHTTPClient() *HTTPClient {
	const requestTimeout = 10 * time.Second
	inner := &http.Client{Timeout: requestTimeout}
	return &HTTPClient{client: inner}
}

// DoWithContext sends a body-less request (method, url) bound to ctx and
// returns the response body as a string.
//
// Trace and request IDs found in ctx are forwarded as X-Trace-ID and
// X-Request-ID headers. Cancelling ctx aborts the request. A response whose
// status is not 2xx is returned as an error, because http.Client.Do only
// reports transport failures and would otherwise make a 500 look like success.
func (c *HTTPClient) DoWithContext(ctx context.Context, method, url string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return "", err
	}

	// Comma-ok assertions: a value of an unexpected type is skipped instead
	// of panicking.
	if traceID, ok := ctx.Value(TraceIDKey).(string); ok {
		req.Header.Set("X-Trace-ID", traceID)
	}
	if requestID, ok := ctx.Value(RequestIDKey).(string); ok {
		req.Header.Set("X-Request-ID", requestID)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Read the body before checking the status so the connection can be
	// reused either way.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("%s %s: unexpected status %s", method, url, resp.Status)
	}

	return string(body), nil
}

// orchestrateServices calls service A and then service B, forwarding trace
// and request IDs through ctx. IDs already present in ctx (for example ones
// TracingMiddleware stored for an incoming request) are kept, so the calls
// join the caller's trace; missing IDs are generated. Service B also gets a
// 2-second timeout on top of any deadline ctx already carries.
func orchestrateServices(ctx context.Context) error {
	client := NewHTTPClient()

	// Only mint IDs at the edge of the system: overwriting an inherited trace
	// ID would split one logical request into unrelated traces.
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok || traceID == "" {
		traceID = generateTraceID()
		ctx = context.WithValue(ctx, TraceIDKey, traceID)
	}
	if requestID, ok := ctx.Value(RequestIDKey).(string); !ok || requestID == "" {
		ctx = context.WithValue(ctx, RequestIDKey, generateRequestID())
	}

	// Call service A
	fmt.Printf("[%s] Calling Service A\n", traceID)
	if _, err := client.DoWithContext(ctx, "GET", "http://localhost:8080/user"); err != nil {
		return fmt.Errorf("service A failed: %w", err)
	}

	// Call service B (with shorter timeout)
	ctxB, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	fmt.Printf("[%s] Calling Service B\n", traceID)
	if _, err := client.DoWithContext(ctxB, "GET", "http://localhost:8081/orders"); err != nil {
		return fmt.Errorf("service B failed: %w", err)
	}

	return nil
}

Database Context Propagation

As in the client example, TraceIDKey comes from the HTTP server example.

package main

import (
	"context"
	"database/sql"
	"fmt"
	"time"
)

// UserRepository performs user queries against a database/sql handle; every
// method forwards its caller's context to the driver.
type UserRepository struct {
	db *sql.DB // connection pool used by all queries
}

// GetUser loads the user with the given ID. ctx is passed to QueryRowContext,
// so the query is abandoned if ctx is cancelled or its deadline passes. Scan
// errors (including sql.ErrNoRows when no row matches) are returned unchanged.
func (r *UserRepository) GetUser(ctx context.Context, userID int) (*User, error) {
	// Comma-ok so a context without a trace ID logs "unknown" instead of
	// fmt's "%!s(<nil>)" placeholder.
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	fmt.Printf("[%s] DB Query: GetUser(%d)\n", traceID, userID)

	// Placeholder syntax is driver-specific: "?" works for drivers such as
	// MySQL and SQLite, while PostgreSQL drivers expect $1.
	const query = "SELECT id, name, email FROM users WHERE id = ?"

	var user User
	if err := r.db.QueryRowContext(ctx, query, userID).Scan(&user.ID, &user.Name, &user.Email); err != nil {
		return nil, err
	}

	return &user, nil
}

// UpdateUserWithContext writes user's name and email inside a transaction
// bound to ctx. BeginTx and ExecContext both receive ctx, so cancellation
// aborts the work. ctx is also checked once more before Commit, so an update
// whose request was abandoned midway is not committed.
func (r *UserRepository) UpdateUserWithContext(ctx context.Context, user *User) error {
	// Comma-ok so a missing trace ID logs "unknown" rather than "%!s(<nil>)".
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	// Start transaction with context
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so
	// deferring it unconditionally is safe; its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	fmt.Printf("[%s] DB Transaction: UpdateUser(%d)\n", traceID, user.ID)

	// All operations in transaction respect context
	_, err = tx.ExecContext(ctx,
		"UPDATE users SET name = ?, email = ? WHERE id = ?",
		user.Name, user.Email, user.ID,
	)
	if err != nil {
		return err
	}

	// Check context before committing
	if err := ctx.Err(); err != nil {
		return err
	}

	return tx.Commit()
}

// User is one row of the users table as read and written by UserRepository.
type User struct {
	ID    int    // primary key
	Name  string // display name
	Email string // contact address
}

Distributed Tracing Pattern

Complete Tracing Implementation

This example reuses generateTraceID from the HTTP server example.

package main

import (
	"context"
	"fmt"
	"time"
)

// Span is one timed unit of work within a trace.
type Span struct {
	TraceID   string            // shared by every span in the same trace
	SpanID    string            // identifies this span
	ParentID  string            // SpanID of the enclosing span; empty for a root
	Name      string            // operation name
	StartTime time.Time         // set when the span is started
	EndTime   time.Time         // set when the span is finished
	Tags      map[string]string // free-form metadata
}

// Tracer collects finished spans in memory. It has no locking, so one Tracer
// should not start or finish spans from several goroutines at once.
type Tracer struct {
	spans []*Span // finished spans, in completion order
}

// NewTracer returns a Tracer with an empty, non-nil span list.
func NewTracer() *Tracer {
	return &Tracer{spans: []*Span{}}
}

// spanKey is the private context key under which the current *Span is stored.
type spanKey struct{}

// StartSpan begins a span called name and returns a context carrying it.
// When ctx already holds a span, the new span joins that span's trace and
// records it as its parent; otherwise a new trace ID is generated and the
// span is a trace root.
func (t *Tracer) StartSpan(ctx context.Context, name string) (context.Context, *Span) {
	parent, _ := ctx.Value(spanKey{}).(*Span)

	span := &Span{
		SpanID:    generateSpanID(),
		Name:      name,
		StartTime: time.Now(),
		Tags:      map[string]string{},
	}

	if parent == nil {
		span.TraceID = generateTraceID()
	} else {
		span.TraceID, span.ParentID = parent.TraceID, parent.SpanID
	}

	fmt.Printf("[TRACE] Started span: %s (trace: %s, span: %s)\n",
		name, span.TraceID, span.SpanID)

	return context.WithValue(ctx, spanKey{}, span), span
}

// FinishSpan stamps span's end time, logs how long it ran, and appends it to
// the tracer's list of finished spans.
func (t *Tracer) FinishSpan(span *Span) {
	span.EndTime = time.Now()

	fmt.Printf("[TRACE] Finished span: %s (duration: %v)\n",
		span.Name, span.EndTime.Sub(span.StartTime))

	t.spans = append(t.spans, span)
}

// AddSpanTag records key=value on the span stored in ctx. It does nothing
// when ctx carries no span (or a nil one). It allocates the Tags map if the
// span was built without one, so a hand-constructed Span cannot trigger a
// nil-map panic.
func AddSpanTag(ctx context.Context, key, value string) {
	span, ok := ctx.Value(spanKey{}).(*Span)
	if !ok || span == nil {
		return
	}
	if span.Tags == nil {
		span.Tags = make(map[string]string)
	}
	span.Tags[key] = value
}

// ProcessOrder shows distributed tracing across several calls. It creates a
// Tracer and opens a root "ProcessOrder" span tagged with the order ID. It
// then runs validation, payment, and shipping in order, each in its own child
// span. The first failing step's error is tagged on the root span and returned.
func ProcessOrder(ctx context.Context, orderID string) error {
	tracer := NewTracer()

	ctx, root := tracer.StartSpan(ctx, "ProcessOrder")
	defer tracer.FinishSpan(root)

	AddSpanTag(ctx, "order.id", orderID)

	steps := []func(context.Context, *Tracer, string) error{
		ValidateOrder,
		ProcessPayment,
		ShipOrder,
	}
	for _, step := range steps {
		if err := step(ctx, tracer, orderID); err != nil {
			AddSpanTag(ctx, "error", err.Error())
			return err
		}
	}

	return nil
}

// ValidateOrder runs order validation inside a "ValidateOrder" child span.
// The simulated 100ms of work stops early with ctx.Err() if ctx is cancelled
// or its deadline passes, instead of sleeping through the cancellation.
func ValidateOrder(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "ValidateOrder")
	defer tracer.FinishSpan(span)

	AddSpanTag(ctx, "service", "validation")

	// Simulate validation
	select {
	case <-time.After(100 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// ProcessPayment charges the order inside a "ProcessPayment" span. The
// gateway call becomes a child span because it receives the span-carrying
// context; any gateway error is returned unchanged.
func ProcessPayment(ctx context.Context, tracer *Tracer, orderID string) error {
	spanCtx, span := tracer.StartSpan(ctx, "ProcessPayment")
	defer tracer.FinishSpan(span)

	AddSpanTag(spanCtx, "service", "payment")

	return CallPaymentGateway(spanCtx, tracer, orderID)
}

// CallPaymentGateway simulates a 200ms external payment API call inside a
// "CallPaymentGateway" span tagged with the gateway name. It returns ctx.Err()
// early if ctx is cancelled or its deadline passes before the call finishes.
func CallPaymentGateway(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "CallPaymentGateway")
	defer tracer.FinishSpan(span)

	AddSpanTag(ctx, "service", "external")
	AddSpanTag(ctx, "gateway", "stripe")

	// Simulate API call without ignoring cancellation
	select {
	case <-time.After(200 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// ShipOrder simulates 150ms of shipping work inside a "ShipOrder" span. It
// returns ctx.Err() early if ctx is cancelled or its deadline passes first.
func ShipOrder(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "ShipOrder")
	defer tracer.FinishSpan(span)

	AddSpanTag(ctx, "service", "shipping")

	// Simulate shipping
	select {
	case <-time.After(150 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// generateSpanID returns "span-" followed by the current Unix time in
// nanoseconds. Uniqueness relies on successive calls seeing distinct clock
// readings.
func generateSpanID() string {
	now := time.Now()
	return fmt.Sprintf("span-%d", now.UnixNano())
}

Deadline Propagation

Cascading Deadlines Across Services

package main

import (
	"context"
	"fmt"
	"time"
)

// ServiceOrchestrator runs a three-step workflow under one overall timeout,
// giving each step a share of the time still left.
type ServiceOrchestrator struct {
	totalTimeout time.Duration // budget for one Execute call
}

// NewServiceOrchestrator returns an orchestrator whose Execute calls are
// bounded by timeout.
func NewServiceOrchestrator(timeout time.Duration) *ServiceOrchestrator {
	o := new(ServiceOrchestrator)
	o.totalTimeout = timeout
	return o
}

// Execute runs steps 1-3 under a single deadline of s.totalTimeout (or the
// parent's deadline, if that is earlier). Step 1 gets 30% of the time
// remaining at the start, step 2 gets 40% of what remains after step 1, and
// step 3 gets the rest. The first step error is returned.
func (s *ServiceOrchestrator) Execute(ctx context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, s.totalTimeout)
	defer cancel()

	// Always present: WithTimeout just installed a deadline.
	deadline, _ := ctx.Deadline()
	share := func(d time.Duration, fraction float64) time.Duration {
		return time.Duration(float64(d) * fraction)
	}

	left := time.Until(deadline)
	fmt.Printf("Total deadline: %v remaining\n", left)

	if err := s.executeStep1(ctx, share(left, 0.3)); err != nil {
		return err
	}
	if err := s.executeStep2(ctx, share(time.Until(deadline), 0.4)); err != nil {
		return err
	}
	return s.executeStep3(ctx)
}

// executeStep1 simulates step 1 within budget. The work takes budget/2 and
// runs under a context whose deadline is budget from now (or the parent's
// deadline, if earlier). Cancellation or timeout is returned wrapped as
// "step 1 timeout: ...".
func (s *ServiceOrchestrator) executeStep1(ctx context.Context, budget time.Duration) error {
	ctx, cancel := context.WithTimeout(ctx, budget)
	defer cancel()

	fmt.Printf("Step 1: Budget %v\n", budget)

	// If ctx is already done (e.g. budget <= 0, which also makes time.After
	// fire at once), select could pick either ready case at random. Fail
	// deterministically instead.
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("step 1 timeout: %w", err)
	}

	select {
	case <-time.After(budget / 2):
		fmt.Println("Step 1: Completed")
		return nil
	case <-ctx.Done():
		return fmt.Errorf("step 1 timeout: %w", ctx.Err())
	}
}

// executeStep2 simulates step 2 within budget. The work takes budget/2 and
// runs under a context whose deadline is budget from now (or the parent's
// deadline, if earlier). Cancellation or timeout is returned wrapped as
// "step 2 timeout: ...".
func (s *ServiceOrchestrator) executeStep2(ctx context.Context, budget time.Duration) error {
	ctx, cancel := context.WithTimeout(ctx, budget)
	defer cancel()

	fmt.Printf("Step 2: Budget %v\n", budget)

	// If ctx is already done (e.g. budget <= 0, which also makes time.After
	// fire at once), select could pick either ready case at random. Fail
	// deterministically instead.
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("step 2 timeout: %w", err)
	}

	select {
	case <-time.After(budget / 2):
		fmt.Println("Step 2: Completed")
		return nil
	case <-ctx.Done():
		return fmt.Errorf("step 2 timeout: %w", ctx.Err())
	}
}

// executeStep3 simulates the final step using the time left before ctx's
// deadline; the work takes half of that remaining budget. It returns an error
// if ctx has no deadline. It returns a "step 3 timeout" error if the deadline
// has already passed or ctx is done before the work finishes.
func (s *ServiceOrchestrator) executeStep3(ctx context.Context) error {
	deadline, ok := ctx.Deadline()
	if !ok {
		return fmt.Errorf("no deadline set")
	}

	budget := time.Until(deadline)
	fmt.Printf("Step 3: Budget %v\n", budget)

	// Once the deadline has passed, time.After(budget/2) fires immediately
	// while ctx.Done() is (or is about to be) closed, so select could report
	// success at random. Report the timeout deterministically instead.
	if budget <= 0 {
		return fmt.Errorf("step 3 timeout: %w", context.DeadlineExceeded)
	}

	select {
	case <-time.After(budget / 2):
		fmt.Println("Step 3: Completed")
		return nil
	case <-ctx.Done():
		return fmt.Errorf("step 3 timeout: %w", ctx.Err())
	}
}

// main runs the three-step workflow once with a 5-second overall budget and
// prints any error.
func main() {
	const overall = 5 * time.Second
	if err := NewServiceOrchestrator(overall).Execute(context.Background()); err != nil {
		fmt.Printf("Error: %v\n", err)
	}
}

Best Practices

1. Always Propagate Context

// BAD: Creating new context. ServiceB has no context parameter, so it starts
// a fresh tree from context.Background(): the caller's cancellation,
// deadline, and request-scoped values never reach it.
func ServiceB() {
	ctx := context.Background() // Loses parent context!
	// ...
}

// GOOD: Accepting and propagating context. ctx is the first parameter, so
// whatever the caller set is available here and can be passed on to every
// downstream call.
func ServiceB(ctx context.Context) {
	// Context propagates through: pass ctx to the next call.
}

2. Don’t Store Context in Structs

// BAD: Storing context in struct. Every method would then use this one
// stored context instead of the context of the call in progress, so callers
// cannot give individual calls their own deadline or cancellation.
type Service struct {
	ctx context.Context // Anti-pattern!
}

// GOOD: Pass context as first parameter
type Service struct {
	// No context field
}

// DoWork receives the caller's context on every call, so each call carries
// its own cancellation, deadline, and request-scoped values.
func (s *Service) DoWork(ctx context.Context) {
	// Context as parameter
}

3. Use Type-Safe Context Keys

// BAD: String keys can collide: any other package that also uses the plain
// string "user_id" reads and overwrites the same entry.
ctx = context.WithValue(ctx, "user_id", 123)

// GOOD: Unexported type prevents collisions, because no other package can
// construct a contextKey value.
type contextKey string

const userIDKey contextKey = "user_id"

// Store the value with the type readers will assert; the lookups in the next
// section use .(string).
ctx = context.WithValue(ctx, userIDKey, "user-123")

4. Handle Missing Context Values

// BAD: Panic on missing value. A single-value type assertion panics when the
// key is absent (ctx.Value returns nil) or holds a value of another type.
userID := ctx.Value(userIDKey).(string)

// GOOD: Safe extraction. The comma-ok form sets ok to false instead of
// panicking, so the caller can fall back to a default.
userID, ok := ctx.Value(userIDKey).(string)
if !ok {
	userID = "unknown"
}

Common Pitfalls

  1. Creating New Background Contexts: Breaks propagation chain
  2. Not Checking ctx.Done(): Ignoring cancellation signals
  3. Passing nil Context: Always use context.Background() or context.TODO()
  4. Storing Context: Context should flow through calls, not be stored
  5. Using Context for Optional Parameters: Use context only for request-scoped values

Performance Considerations

  • Context propagation is lightweight (minimal overhead)
  • Value lookups traverse parent chain (O(n) where n is depth)
  • Keep context value chain shallow
  • Use context values sparingly on hot paths
  • Look a value up once and pass it explicitly instead of repeating ctx.Value calls

Testing Context Propagation

package main

import (
	"context"
	"testing"
	"time"
)

// TestContextPropagation exercises the examples above. The first subtest only
// checks that ServiceA succeeds when the caller's context carries a trace ID;
// it does not observe the ID downstream. The second checks that
// longRunningOperation returns promptly with context.Canceled once its
// context is cancelled.
func TestContextPropagation(t *testing.T) {
	t.Run("trace ID propagates", func(t *testing.T) {
		ctx := context.WithValue(context.Background(), TraceIDKey, "test-trace-123")

		// Call service that should propagate context
		err := ServiceA(ctx, "user-1")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// Verify trace ID was propagated (check logs or traces)
	})

	t.Run("cancellation propagates", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// Buffered so the goroutine can always deliver its result and exit,
		// even if the test has already given up waiting below.
		done := make(chan error, 1)
		go func() {
			done <- longRunningOperation(ctx)
		}()

		// Cancel context
		cancel()

		// Verify cancellation
		select {
		case err := <-done:
			// longRunningOperation returns ctx.Err() unwrapped, so comparing
			// with the sentinel directly is sufficient here.
			if err != context.Canceled {
				t.Errorf("expected context.Canceled, got %v", err)
			}
		case <-time.After(1 * time.Second):
			t.Fatal("operation did not respect cancellation")
		}
	})
}

// longRunningOperation simulates 10 seconds of work. It returns nil if the
// work finishes, or ctx.Err() as soon as ctx is cancelled or times out.
func longRunningOperation(ctx context.Context) error {
	work := time.NewTimer(10 * time.Second)
	defer work.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-work.C:
		return nil
	}
}

Conclusion

Context propagation is essential for building observable, responsive distributed systems in Go. By threading context through your application, you enable:

  • Distributed tracing across service boundaries
  • Graceful cancellation of entire request trees
  • Deadline enforcement for time-sensitive operations
  • Request-scoped metadata without polluting signatures

Key Takeaways:

  • Always accept context as the first parameter
  • Propagate context through all layers
  • Use type-safe keys for context values
  • Respect cancellation signals with ctx.Done()
  • Set appropriate timeouts and deadlines

Next, explore the Go Memory Model to understand how Go guarantees visibility and ordering across goroutines.


Previous: Circuit Breaker Pattern | Next: Go Memory Model Explained | Series: Go Concurrency Patterns