Go Concurrency Patterns Series: ← Circuit Breaker | Series Overview | Memory Model →
What is Context Propagation?
Context propagation is the practice of threading context.Context through your application to carry cancellation signals, deadlines, and request-scoped values across API boundaries, between goroutines, and across service boundaries. This is critical for building observable, responsive distributed systems.
Key Capabilities:
- Distributed Tracing: Propagate trace IDs across services
- Cancellation Cascades: Cancel entire request trees
- Deadline Enforcement: Ensure requests complete within time budgets
- Request-Scoped Values: Carry metadata without polluting function signatures
Real-World Use Cases
- Microservices: Trace requests across multiple services
- API Gateways: Propagate timeouts and user context
- Database Layers: Cancel queries when requests are abandoned
- Message Queues: Propagate processing deadlines
- HTTP Middleware: Extract and inject trace headers
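- gRPC Services: Deadlines and cancellation propagate automatically across RPC calls; request metadata such as trace IDs must be forwarded explicitly (typically with interceptors)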
Basic Context Propagation
Propagating Through Function Calls
package main
import (
"context"
"fmt"
"time"
)
// ServiceA is the entry point of the demo call chain ServiceA -> ServiceB ->
// ServiceC. It records userID and a freshly generated request ID on the
// context and returns whatever ServiceB returns.
//
// NOTE: plain string keys are used here for brevity; linters such as
// staticcheck flag built-in key types because they can collide across packages.
func ServiceA(ctx context.Context, userID string) error {
	withUser := context.WithValue(ctx, "user_id", userID)
	withIDs := context.WithValue(withUser, "request_id", generateRequestID())

	fmt.Printf("[ServiceA] Processing request for user: %s\n", userID)

	// Hand the enriched context to the next layer.
	return ServiceB(withIDs)
}
// ServiceB logs the user and request IDs carried on ctx, then calls ServiceC
// under a timeout of at most 2 seconds derived from ctx, returning ServiceC's
// error.
//
// Values are read with comma-ok assertions: a missing or non-string value is
// logged as "unknown" instead of panicking.
func ServiceB(ctx context.Context) error {
	userID, ok := ctx.Value("user_id").(string)
	if !ok {
		userID = "unknown"
	}
	requestID, ok := ctx.Value("request_id").(string)
	if !ok {
		requestID = "unknown"
	}
	fmt.Printf("[ServiceB] User: %s, Request: %s\n", userID, requestID)

	// Bound the downstream call; cancel releases the timer as soon as we return.
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	return ServiceC(ctx)
}
// ServiceC logs the user and request IDs carried on ctx and simulates one
// second of work. It returns nil when the work finishes, or ctx.Err() if ctx
// is cancelled or its deadline passes first.
//
// Values are read with comma-ok assertions: a missing or non-string value is
// logged as "unknown" instead of panicking.
func ServiceC(ctx context.Context) error {
	userID, ok := ctx.Value("user_id").(string)
	if !ok {
		userID = "unknown"
	}
	requestID, ok := ctx.Value("request_id").(string)
	if !ok {
		requestID = "unknown"
	}
	fmt.Printf("[ServiceC] Processing for User: %s, Request: %s\n", userID, requestID)

	// Simulate work while staying responsive to cancellation.
	select {
	case <-time.After(1 * time.Second):
		fmt.Println("[ServiceC] Work completed")
		return nil
	case <-ctx.Done():
		fmt.Printf("[ServiceC] Cancelled: %v\n", ctx.Err())
		return ctx.Err()
	}
}
// generateRequestID returns "req-" followed by the current Unix time in
// nanoseconds. Two calls that observe the same clock reading yield the same ID.
func generateRequestID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprintf("req-%d", nanos)
}
// main runs the ServiceA -> ServiceB -> ServiceC demo chain for a fixed user
// and prints any error it returns.
func main() {
	if err := ServiceA(context.Background(), "user-123"); err != nil {
		fmt.Printf("Error: %v\n", err)
	}
}
Output (the request ID differs on every run):
[ServiceA] Processing request for user: user-123
[ServiceB] User: user-123, Request: req-1234567890
[ServiceC] Processing for User: user-123, Request: req-1234567890
[ServiceC] Work completed
HTTP Request Context Propagation
Server-Side: Extracting and Propagating Context
package main
import (
"context"
"fmt"
"net/http"
"time"
)
// contextKey is an unexported key type: because context keys are compared as
// interface values, a contextKey never equals a plain string key used by
// another package, even when the underlying text is the same.
type contextKey string

// Context keys for request-scoped metadata. TraceIDKey and RequestIDKey are
// populated by TracingMiddleware.
const (
	TraceIDKey   = contextKey("trace_id")
	RequestIDKey = contextKey("request_id")
	UserIDKey    = contextKey("user_id")
)
// TracingMiddleware makes sure every request carries a trace ID and a request ID.
//
// It reuses the X-Trace-ID / X-Request-ID request headers when present and
// generates fresh IDs otherwise. It then echoes both IDs back as response
// headers and logs the start of the request. Finally it calls next with a
// request whose context holds the IDs under TraceIDKey and RequestIDKey.
func TracingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// pick returns the header value, or a generated ID when it is empty.
		pick := func(header string, generate func() string) string {
			if v := r.Header.Get(header); v != "" {
				return v
			}
			return generate()
		}
		traceID := pick("X-Trace-ID", generateTraceID)
		requestID := pick("X-Request-ID", generateRequestID)

		hdr := w.Header()
		hdr.Set("X-Trace-ID", traceID)
		hdr.Set("X-Request-ID", requestID)

		fmt.Printf("[%s] %s %s started\n", traceID, r.Method, r.URL.Path)

		enriched := context.WithValue(r.Context(), TraceIDKey, traceID)
		enriched = context.WithValue(enriched, RequestIDKey, requestID)
		next.ServeHTTP(w, r.WithContext(enriched))
	})
}
// TimeoutMiddleware returns a middleware that gives each request a context
// expiring after timeout (or earlier, if the incoming context already has an
// earlier deadline). The derived context is cancelled as soon as next returns,
// which releases its timer.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			bounded, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()
			next.ServeHTTP(w, r.WithContext(bounded))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// handleUserRequest serves /user. It fetches data for a fixed demo user using
// the request context, so the fetch stops when the client goes away or the
// context deadline (e.g. from TimeoutMiddleware) passes.
//
// The trace ID is read with a comma-ok assertion, so the handler logs
// "unknown" instead of panicking when TracingMiddleware is not installed.
// A fetch abandoned because the request deadline expired is answered with
// 504 Gateway Timeout; other failures get 500.
func handleUserRequest(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	// Simulate calling another service
	userData, err := fetchUserData(ctx, "user-123")
	if err != nil {
		fmt.Printf("[%s] Error fetching user: %v\n", traceID, err)
		// A blown deadline is a timeout, not an internal failure.
		if ctx.Err() == context.DeadlineExceeded {
			http.Error(w, "Gateway Timeout", http.StatusGatewayTimeout)
			return
		}
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	fmt.Printf("[%s] Request completed successfully\n", traceID)
	fmt.Fprintf(w, "User data: %s\n", userData)
}
// fetchUserData simulates a 500ms lookup for userID. It returns
// "Data for <userID>" on success, or ctx.Err() if ctx is cancelled or its
// deadline passes first.
//
// The trace ID is read with a comma-ok assertion, so a context without
// TracingMiddleware's values logs "unknown" rather than panicking.
func fetchUserData(ctx context.Context, userID string) (string, error) {
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}
	fmt.Printf("[%s] Fetching user data for: %s\n", traceID, userID)

	// Simulate database call
	select {
	case <-time.After(500 * time.Millisecond):
		return fmt.Sprintf("Data for %s", userID), nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}
// generateTraceID returns "trace-" followed by the current Unix time in
// nanoseconds. Two calls that observe the same clock reading yield the same ID.
func generateTraceID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprintf("trace-%d", nanos)
}
// main serves /user on :8080 behind TracingMiddleware and a 5-second
// TimeoutMiddleware.
//
// Tracing is the outer layer, so the timeout context is derived from the
// trace-enriched request context. The http.Server carries its own header-read
// and write timeouts so slow clients cannot hold connections open
// indefinitely. ListenAndServe's error, which it always returns, is reported
// rather than dropped.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/user", handleUserRequest)

	srv := &http.Server{
		Addr:              ":8080",
		Handler:           TracingMiddleware(TimeoutMiddleware(5 * time.Second)(mux)),
		ReadHeaderTimeout: 5 * time.Second,
		// Longer than the 5s handler timeout so timed-out handlers can still respond.
		WriteTimeout: 10 * time.Second,
	}

	fmt.Println("Server starting on :8080")
	if err := srv.ListenAndServe(); err != nil {
		fmt.Printf("Server error: %v\n", err)
	}
}
Client-Side: Propagating Context in HTTP Clients
package main
import (
"context"
"fmt"
"io"
"net/http"
"time"
)
// HTTPClient wraps an *http.Client and forwards trace/request IDs found on the
// context as outbound request headers (see DoWithContext).
type HTTPClient struct {
	client *http.Client
}

// NewHTTPClient returns an HTTPClient whose underlying client enforces a
// 10-second overall timeout per request, in addition to any context deadline.
func NewHTTPClient() *HTTPClient {
	inner := &http.Client{Timeout: 10 * time.Second}
	return &HTTPClient{client: inner}
}
// DoWithContext issues a body-less request with the given method to url,
// bound to ctx, and returns the response body as a string.
//
// If ctx carries string values under TraceIDKey or RequestIDKey, they are sent
// as the X-Trace-ID and X-Request-ID headers. Values of any other type are
// skipped rather than causing a panic.
//
// A response with a non-2xx status is returned as an error, so callers do not
// mistake an error page for a successful result. The body is still read in
// full first so the underlying connection can be reused.
func (c *HTTPClient) DoWithContext(ctx context.Context, method, url string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return "", err
	}

	// Propagate trace headers
	if traceID, ok := ctx.Value(TraceIDKey).(string); ok {
		req.Header.Set("X-Trace-ID", traceID)
	}
	if requestID, ok := ctx.Value(RequestIDKey).(string); ok {
		req.Header.Set("X-Request-ID", requestID)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("%s %s: unexpected status %s", method, url, resp.Status)
	}
	return string(body), nil
}
// orchestrateServices demonstrates calling two downstream services that share
// one trace ID and one request ID.
//
// It calls Service A (localhost:8080/user) under the caller's context. It then
// calls Service B (localhost:8081/orders) under that context further capped at
// 2 seconds. The first failure is returned, wrapped with the service's name.
func orchestrateServices(ctx context.Context) error {
	client := NewHTTPClient()

	traceID := generateTraceID()
	traced := context.WithValue(ctx, TraceIDKey, traceID)
	traced = context.WithValue(traced, RequestIDKey, generateRequestID())

	fmt.Printf("[%s] Calling Service A\n", traceID)
	if _, err := client.DoWithContext(traced, "GET", "http://localhost:8080/user"); err != nil {
		return fmt.Errorf("service A failed: %w", err)
	}

	// Service B gets a tighter budget than the caller's context.
	short, cancel := context.WithTimeout(traced, 2*time.Second)
	defer cancel()

	fmt.Printf("[%s] Calling Service B\n", traceID)
	if _, err := client.DoWithContext(short, "GET", "http://localhost:8081/orders"); err != nil {
		return fmt.Errorf("service B failed: %w", err)
	}
	return nil
}
Database Context Propagation
package main
import (
"context"
"database/sql"
"fmt"
"time"
)
// UserRepository provides context-aware access to the users table.
type UserRepository struct {
	db *sql.DB
}

// GetUser loads the user with the given ID.
//
// ctx is passed to QueryRowContext, so the query is abandoned if ctx is
// cancelled or its deadline passes. The Scan error is returned unchanged
// (e.g. sql.ErrNoRows when no row matches), so callers can compare it directly.
// The trace ID is logged as "unknown" when ctx does not carry a string under
// TraceIDKey, instead of printing a malformed %s verb.
//
// NOTE(review): "?" placeholders are driver-specific (PostgreSQL drivers use
// $1); confirm against the driver in use.
func (r *UserRepository) GetUser(ctx context.Context, userID int) (*User, error) {
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}
	fmt.Printf("[%s] DB Query: GetUser(%d)\n", traceID, userID)

	const query = "SELECT id, name, email FROM users WHERE id = ?"
	var user User
	if err := r.db.QueryRowContext(ctx, query, userID).Scan(&user.ID, &user.Name, &user.Email); err != nil {
		return nil, err
	}
	return &user, nil
}
// UpdateUserWithContext writes user's name and email inside a transaction
// started with ctx, so cancellation or a deadline on ctx aborts the work.
//
// ctx is re-checked just before Commit, so work whose caller has already given
// up is rolled back rather than committed. The trace ID is logged as "unknown"
// when ctx does not carry a string under TraceIDKey.
func (r *UserRepository) UpdateUserWithContext(ctx context.Context, user *User) error {
	traceID, ok := ctx.Value(TraceIDKey).(string)
	if !ok {
		traceID = "unknown"
	}

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rolls back on every early return. After a successful Commit, Rollback
	// returns sql.ErrTxDone, which is intentionally ignored.
	defer tx.Rollback()

	fmt.Printf("[%s] DB Transaction: UpdateUser(%d)\n", traceID, user.ID)

	// All operations in the transaction respect ctx.
	if _, err := tx.ExecContext(ctx,
		"UPDATE users SET name = ?, email = ? WHERE id = ?",
		user.Name, user.Email, user.ID,
	); err != nil {
		return err
	}

	// Don't commit if the caller has already cancelled or timed out.
	if err := ctx.Err(); err != nil {
		return err
	}
	return tx.Commit()
}
// User mirrors one row of the users table (id, name, email).
type User struct {
	ID          int
	Name, Email string
}
Distributed Tracing Pattern
Complete Tracing Implementation
package main
import (
"context"
"fmt"
"time"
)
// Span records one timed unit of work within a trace.
//
// ParentID is empty for a root span. Tags holds free-form key/value metadata.
type Span struct {
	TraceID, SpanID, ParentID string
	Name                      string
	StartTime, EndTime        time.Time
	Tags                      map[string]string
}
// Tracer collects finished spans in the order FinishSpan is called.
// It has no locking, so it is not safe for concurrent use.
type Tracer struct {
	spans []*Span
}

// NewTracer returns a Tracer with an empty, non-nil span list.
func NewTracer() *Tracer {
	t := new(Tracer)
	t.spans = []*Span{}
	return t
}

// spanKey is the unexported context key under which the current *Span is stored.
type spanKey struct{}
// StartSpan opens a span called name and returns a child context carrying it.
//
// If ctx already holds a span, the new span joins that span's trace and records
// the parent's SpanID as its ParentID. Otherwise a fresh trace ID is generated
// and the new span is a root. The start is printed to stdout. Spans are only
// recorded on the Tracer by FinishSpan.
func (t *Tracer) StartSpan(ctx context.Context, name string) (context.Context, *Span) {
	parent, _ := ctx.Value(spanKey{}).(*Span)

	span := &Span{
		SpanID:    generateSpanID(),
		Name:      name,
		StartTime: time.Now(),
		Tags:      map[string]string{},
	}

	switch {
	case parent == nil:
		span.TraceID = generateTraceID()
	default:
		span.TraceID = parent.TraceID
		span.ParentID = parent.SpanID
	}

	fmt.Printf("[TRACE] Started span: %s (trace: %s, span: %s)\n", name, span.TraceID, span.SpanID)
	return context.WithValue(ctx, spanKey{}, span), span
}
// FinishSpan stamps span's end time, prints its duration, and appends it to
// the tracer's list of finished spans.
func (t *Tracer) FinishSpan(span *Span) {
	end := time.Now()
	span.EndTime = end
	fmt.Printf("[TRACE] Finished span: %s (duration: %v)\n", span.Name, end.Sub(span.StartTime))
	t.spans = append(t.spans, span)
}
// AddSpanTag attaches key=value to the span stored in ctx by StartSpan.
//
// It does nothing when ctx carries no span, including a typed-nil *Span.
// If the span's Tags map is nil (e.g. a Span built by hand rather than via
// StartSpan), the map is allocated on first use instead of panicking.
func AddSpanTag(ctx context.Context, key, value string) {
	span, ok := ctx.Value(spanKey{}).(*Span)
	if !ok || span == nil {
		return
	}
	if span.Tags == nil {
		span.Tags = make(map[string]string)
	}
	span.Tags[key] = value
}
// ProcessOrder runs validation, payment, and shipping for orderID, each in a
// child span under a root "ProcessOrder" span tagged with the order ID.
//
// Steps run in order and the first error stops the sequence. That error is
// also recorded as an "error" tag on the root span before being returned.
func ProcessOrder(ctx context.Context, orderID string) error {
	tracer := NewTracer()

	ctx, root := tracer.StartSpan(ctx, "ProcessOrder")
	defer tracer.FinishSpan(root)
	AddSpanTag(ctx, "order.id", orderID)

	steps := []func(context.Context, *Tracer, string) error{
		ValidateOrder,
		ProcessPayment,
		ShipOrder,
	}
	for _, step := range steps {
		if err := step(ctx, tracer, orderID); err != nil {
			AddSpanTag(ctx, "error", err.Error())
			return err
		}
	}
	return nil
}
// ValidateOrder runs the simulated validation step in a child span tagged
// service=validation.
//
// The simulated 100ms of work stops early when ctx is cancelled or its
// deadline passes. The error is then tagged on the span and ctx.Err() is
// returned. time.Sleep would have ignored cancellation entirely.
func ValidateOrder(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "ValidateOrder")
	defer tracer.FinishSpan(span)
	AddSpanTag(ctx, "service", "validation")

	select {
	case <-time.After(100 * time.Millisecond):
		return nil
	case <-ctx.Done():
		err := ctx.Err()
		AddSpanTag(ctx, "error", err.Error())
		return err
	}
}
// ProcessPayment opens a child span tagged service=payment and delegates to
// CallPaymentGateway, whose span nests under this one. It returns the
// gateway's error, if any.
func ProcessPayment(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "ProcessPayment")
	defer tracer.FinishSpan(span)
	AddSpanTag(ctx, "service", "payment")

	return CallPaymentGateway(ctx, tracer, orderID)
}
// CallPaymentGateway simulates an external payment API call in a child span
// tagged service=external and gateway=stripe.
//
// The simulated 200ms call stops early when ctx is cancelled or its deadline
// passes. The error is then tagged on the span and ctx.Err() is returned.
// time.Sleep would have ignored cancellation entirely.
func CallPaymentGateway(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "CallPaymentGateway")
	defer tracer.FinishSpan(span)
	AddSpanTag(ctx, "service", "external")
	AddSpanTag(ctx, "gateway", "stripe")

	select {
	case <-time.After(200 * time.Millisecond):
		return nil
	case <-ctx.Done():
		err := ctx.Err()
		AddSpanTag(ctx, "error", err.Error())
		return err
	}
}
// ShipOrder runs the simulated shipping step in a child span tagged
// service=shipping.
//
// The simulated 150ms of work stops early when ctx is cancelled or its
// deadline passes. The error is then tagged on the span and ctx.Err() is
// returned. time.Sleep would have ignored cancellation entirely.
func ShipOrder(ctx context.Context, tracer *Tracer, orderID string) error {
	ctx, span := tracer.StartSpan(ctx, "ShipOrder")
	defer tracer.FinishSpan(span)
	AddSpanTag(ctx, "service", "shipping")

	select {
	case <-time.After(150 * time.Millisecond):
		return nil
	case <-ctx.Done():
		err := ctx.Err()
		AddSpanTag(ctx, "error", err.Error())
		return err
	}
}
// generateSpanID returns "span-" followed by the current Unix time in
// nanoseconds. Two calls that observe the same clock reading yield the same ID.
func generateSpanID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprintf("span-%d", nanos)
}
Deadline Propagation
Cascading Deadlines Across Services
package main
import (
"context"
"fmt"
"time"
)
// ServiceOrchestrator runs a three-step workflow under one overall time budget
// and divides that budget between the steps.
type ServiceOrchestrator struct {
	totalTimeout time.Duration
}

// NewServiceOrchestrator returns an orchestrator whose Execute is bounded by timeout.
func NewServiceOrchestrator(timeout time.Duration) *ServiceOrchestrator {
	o := new(ServiceOrchestrator)
	o.totalTimeout = timeout
	return o
}
// Execute runs the three steps under a single deadline of s.totalTimeout from
// now, or the parent's deadline if that is earlier.
//
// Step 1 gets 30% of the time remaining at the start. Step 2 gets 40% of what
// is left after step 1. Step 3 runs until the overall deadline. The first step
// error is returned.
func (s *ServiceOrchestrator) Execute(ctx context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, s.totalTimeout)
	defer cancel()

	// WithTimeout always installs a deadline, so the ok result is not needed.
	deadline, _ := ctx.Deadline()
	left := time.Until(deadline)
	fmt.Printf("Total deadline: %v remaining\n", left)

	if err := s.executeStep1(ctx, time.Duration(0.3*float64(left))); err != nil {
		return err
	}

	// Re-measure: step 2's share is taken from what step 1 left over.
	if err := s.executeStep2(ctx, time.Duration(0.4*float64(time.Until(deadline)))); err != nil {
		return err
	}

	return s.executeStep3(ctx)
}
// executeStep1 simulates a step that needs half of its budget. It runs under a
// child context that expires after budget and returns a wrapped ctx error if
// that context ends first.
func (s *ServiceOrchestrator) executeStep1(ctx context.Context, budget time.Duration) error {
	stepCtx, cancel := context.WithTimeout(ctx, budget)
	defer cancel()
	fmt.Printf("Step 1: Budget %v\n", budget)

	work := time.After(budget / 2)
	select {
	case <-stepCtx.Done():
		return fmt.Errorf("step 1 timeout: %w", stepCtx.Err())
	case <-work:
		fmt.Println("Step 1: Completed")
		return nil
	}
}
// executeStep2 simulates a step that needs half of its budget. It runs under a
// child context that expires after budget and returns a wrapped ctx error if
// that context ends first.
func (s *ServiceOrchestrator) executeStep2(ctx context.Context, budget time.Duration) error {
	stepCtx, cancel := context.WithTimeout(ctx, budget)
	defer cancel()
	fmt.Printf("Step 2: Budget %v\n", budget)

	work := time.After(budget / 2)
	select {
	case <-stepCtx.Done():
		return fmt.Errorf("step 2 timeout: %w", stepCtx.Err())
	case <-work:
		fmt.Println("Step 2: Completed")
		return nil
	}
}
// executeStep3 uses whatever time remains before ctx's deadline as its budget
// and simulates work taking half of it. It fails immediately if ctx has no
// deadline, and returns a wrapped ctx error if ctx ends first.
func (s *ServiceOrchestrator) executeStep3(ctx context.Context) error {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return fmt.Errorf("no deadline set")
	}

	budget := time.Until(deadline)
	fmt.Printf("Step 3: Budget %v\n", budget)

	work := time.After(budget / 2)
	select {
	case <-ctx.Done():
		return fmt.Errorf("step 3 timeout: %w", ctx.Err())
	case <-work:
		fmt.Println("Step 3: Completed")
		return nil
	}
}
// main runs the three-step orchestrator with a 5-second overall budget and
// prints any error.
func main() {
	orchestrator := NewServiceOrchestrator(5 * time.Second)
	if err := orchestrator.Execute(context.Background()); err != nil {
		fmt.Printf("Error: %v\n", err)
	}
}
Best Practices
1. Always Propagate Context
// BAD: Creating new context
// The caller's cancellation, deadline, and values are unreachable from here.
func ServiceB() {
ctx := context.Background() // Loses parent context!
// ...
}
// GOOD: Accepting and propagating context
// The caller controls this call's lifetime; pass ctx (or a context derived
// from it) to every downstream call.
func ServiceB(ctx context.Context) {
// Context propagates through
}
2. Don’t Store Context in Structs
// BAD: Storing context in struct
// The stored ctx belongs to whichever call created the Service, not to the
// calls later made on it, so their own cancellation and deadlines are lost.
type Service struct {
ctx context.Context // Anti-pattern!
}
// GOOD: Pass context as first parameter
type Service struct {
// No context field
}
// Each call supplies its own ctx, so that call's cancellation and deadline apply.
func (s *Service) DoWork(ctx context.Context) {
// Context as parameter
}
3. Use Type-Safe Context Keys
// BAD: String keys can collide
// Any other code that also uses the plain string "user_id" as a key can read
// this value or shadow it.
ctx = context.WithValue(ctx, "user_id", 123)
// GOOD: Unexported type prevents collisions
// Keys are compared as interface values, so contextKey("user_id") never
// equals the plain string "user_id": their dynamic types differ.
type contextKey string
const userIDKey contextKey = "user_id"
ctx = context.WithValue(ctx, userIDKey, 123)
4. Handle Missing Context Values
// BAD: Panic on missing value
// The single-result assertion panics if the key is absent or holds a non-string.
userID := ctx.Value(userIDKey).(string)
// GOOD: Safe extraction
// The comma-ok form never panics. ok is false when the key is absent or the
// stored value is not a string (for example, the int 123 stored in the
// previous example).
userID, ok := ctx.Value(userIDKey).(string)
if !ok {
userID = "unknown"
}
Common Pitfalls
- Creating New Background Contexts: Breaks propagation chain
- Not Checking ctx.Done(): Ignoring cancellation signals
- Passing nil Context: Always use context.Background() or context.TODO()
- Storing Context: Context should flow through calls, not be stored
- Using Context for Optional Parameters: Use context only for request-scoped values
Performance Considerations
- Context propagation is lightweight (minimal overhead)
- Value lookups walk the parent chain (O(n), where n is the number of derived contexts above the caller)
- Keep context value chain shallow
- Use context values sparingly on hot paths
- Consider caching frequently accessed values
Testing Context Propagation
package main
import (
"context"
"testing"
"time"
)
// TestContextPropagation checks that ServiceA runs cleanly with a traced
// context, and that cancelling a context stops longRunningOperation promptly.
func TestContextPropagation(t *testing.T) {
	t.Run("trace ID propagates", func(t *testing.T) {
		ctx := context.WithValue(context.Background(), TraceIDKey, "test-trace-123")
		// Call service that should propagate context
		err := ServiceA(ctx, "user-1")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// NOTE(review): this subtest only checks that ServiceA succeeds; nothing
		// asserts that TraceIDKey reached downstream code. A hook that captures
		// the downstream context would turn it into a real propagation check.
	})

	t.Run("cancellation propagates", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())

		// Buffered so the worker goroutine can always deliver its result and
		// exit, even if this subtest has already failed via the timeout
		// branch. Unbuffered, it would block forever on the send.
		done := make(chan error, 1)
		go func() {
			done <- longRunningOperation(ctx)
		}()

		// Cancel context
		cancel()

		// Verify cancellation
		select {
		case err := <-done:
			if err != context.Canceled {
				t.Errorf("expected context.Canceled, got %v", err)
			}
		case <-time.After(1 * time.Second):
			t.Fatal("operation did not respect cancellation")
		}
	})
}
// longRunningOperation simulates 10 seconds of work. It returns nil if the
// work finishes, or ctx.Err() if ctx is cancelled or expires first.
func longRunningOperation(ctx context.Context) error {
	finished := time.After(10 * time.Second)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-finished:
		return nil
	}
}
Conclusion
Context propagation is essential for building observable, responsive distributed systems in Go. By threading context through your application, you enable:
- Distributed tracing across service boundaries
- Graceful cancellation of entire request trees
- Deadline enforcement for time-sensitive operations
- Request-scoped metadata without polluting signatures
Key Takeaways:
- Always accept context as first parameter
- Propagate context through all layers
- Use type-safe keys for context values
- Respect cancellation signals with ctx.Done()
- Set appropriate timeouts and deadlines
Next, explore the Go Memory Model to understand how Go guarantees visibility and ordering across goroutines.
Previous: Circuit Breaker Pattern Next: Go Memory Model Explained Series: Go Concurrency Patterns