Go Concurrency Patterns Series: ← Mutex Patterns | Series Overview | Once Pattern →


What is the WaitGroup Pattern?

The WaitGroup pattern uses sync.WaitGroup to coordinate the completion of multiple goroutines. A WaitGroup maintains a counter of outstanding goroutines, and Wait blocks until that counter reaches zero, which makes it a natural fit for barriers and for waiting on parallel tasks to finish.

Key Operations:

  • Add(n): Increment the counter by n
  • Done(): Decrement the counter by 1 (usually called with defer)
  • Wait(): Block until counter reaches zero
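
A minimal sketch of the three calls working together (the worker count here is arbitrary):

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    for i := 0; i < 3; i++ {
        wg.Add(1) // register the goroutine before starting it
        go func(id int) {
            defer wg.Done() // decrement the counter when finished
            fmt.Printf("worker %d done\n", id)
        }(i)
    }

    wg.Wait() // block until the counter reaches zero
    fmt.Println("all workers finished")
}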

Real-World Use Cases

  • Parallel Processing: Wait for all workers to complete
  • Batch Operations: Process multiple items concurrently
  • Service Initialization: Wait for all services to start
  • Data Collection: Gather results from multiple sources
  • Cleanup Operations: Ensure all cleanup tasks finish
  • Testing: Coordinate test goroutines

Basic WaitGroup Usage

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

// Task represents work to be done
type Task struct {
    ID   int
    Name string
}

// processTask simulates processing a task
func processTask(task Task, wg *sync.WaitGroup) {
    defer wg.Done() // Always call Done when the goroutine finishes
    
    fmt.Printf("Starting task %d: %s\n", task.ID, task.Name)
    
    // Simulate work
    duration := time.Duration(rand.Intn(1000)) * time.Millisecond
    time.Sleep(duration)
    
    fmt.Printf("Completed task %d: %s (took %v)\n", task.ID, task.Name, duration)
}

func main() {
    tasks := []Task{
        {1, "Process images"},
        {2, "Send emails"},
        {3, "Update database"},
        {4, "Generate reports"},
        {5, "Backup files"},
    }
    
    var wg sync.WaitGroup
    
    fmt.Println("Starting parallel task processing...")
    
    // Start all tasks
    for _, task := range tasks {
        wg.Add(1) // Increment counter for each goroutine
        go processTask(task, &wg)
    }
    
    // Wait for all tasks to complete
    wg.Wait()
    
    fmt.Println("All tasks completed!")
}
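
A sync.WaitGroup must not be copied after first use, which is why processTask receives a pointer. A common variation keeps the WaitGroup out of the worker entirely and lets a small closure own the synchronization; a sketch of that variation using the same Task type (the helper name runAll is illustrative):

// runAll launches one goroutine per task and blocks until all complete.
func runAll(tasks []Task, work func(Task)) {
    var wg sync.WaitGroup
    for _, task := range tasks {
        wg.Add(1)
        go func(t Task) {
            defer wg.Done() // the closure owns the WaitGroup bookkeeping
            work(t)
        }(task)
    }
    wg.Wait()
}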

WaitGroup with Error Handling

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

// Result represents the outcome of a task
type Result struct {
    TaskID int
    Data   interface{}
    Error  error
}

// TaskProcessor handles tasks with error collection
type TaskProcessor struct {
    wg      sync.WaitGroup
    results chan Result
    errors  []error
    mu      sync.Mutex
}

// NewTaskProcessor creates a new task processor
func NewTaskProcessor(bufferSize int) *TaskProcessor {
    return &TaskProcessor{
        results: make(chan Result, bufferSize),
    }
}

// processTaskWithError simulates task processing that might fail
func (tp *TaskProcessor) processTaskWithError(taskID int, data interface{}) {
    defer tp.wg.Done()
    
    fmt.Printf("Processing task %d\n", taskID)
    
    // Simulate work
    time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
    
    // Simulate random failures
    if rand.Float32() < 0.3 {
        err := fmt.Errorf("task %d failed", taskID)
        tp.results <- Result{TaskID: taskID, Error: err}
        
        // Collect error
        tp.mu.Lock()
        tp.errors = append(tp.errors, err)
        tp.mu.Unlock()
        
        fmt.Printf("Task %d failed\n", taskID)
        return
    }
    
    // Success
    result := fmt.Sprintf("Result from task %d", taskID)
    tp.results <- Result{TaskID: taskID, Data: result}
    fmt.Printf("Task %d completed successfully\n", taskID)
}

// ProcessTasks processes multiple tasks and collects results
func (tp *TaskProcessor) ProcessTasks(taskCount int) ([]Result, []error) {
    // Start all tasks
    for i := 1; i <= taskCount; i++ {
        tp.wg.Add(1)
        go tp.processTaskWithError(i, fmt.Sprintf("data-%d", i))
    }
    
    // Close results channel when all tasks complete
    go func() {
        tp.wg.Wait()
        close(tp.results)
    }()
    
    // Collect results
    var results []Result
    for result := range tp.results {
        results = append(results, result)
    }
    
    tp.mu.Lock()
    errors := make([]error, len(tp.errors))
    copy(errors, tp.errors)
    tp.mu.Unlock()
    
    return results, errors
}

func main() {
    processor := NewTaskProcessor(10)
    
    fmt.Println("Starting task processing with error handling...")
    
    results, errors := processor.ProcessTasks(8)
    
    fmt.Printf("\nProcessing complete!\n")
    fmt.Printf("Successful tasks: %d\n", len(results)-len(errors))
    fmt.Printf("Failed tasks: %d\n", len(errors))
    
    if len(errors) > 0 {
        fmt.Println("\nErrors:")
        for _, err := range errors {
            fmt.Printf("  - %v\n", err)
        }
    }
    
    fmt.Println("\nResults:")
    for _, result := range results {
        if result.Error == nil {
            fmt.Printf("  Task %d: %v\n", result.TaskID, result.Data)
        }
    }
}
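
When only the first failure matters and the per-task results are not needed, the golang.org/x/sync/errgroup package wraps this WaitGroup-plus-error bookkeeping for you. A minimal sketch, assuming the golang.org/x/sync module is available:

package main

import (
    "fmt"
    "math/rand"

    "golang.org/x/sync/errgroup"
)

func main() {
    var g errgroup.Group

    for i := 1; i <= 8; i++ {
        taskID := i
        g.Go(func() error {
            if rand.Float32() < 0.3 {
                return fmt.Errorf("task %d failed", taskID)
            }
            fmt.Printf("task %d completed\n", taskID)
            return nil
        })
    }

    // Wait blocks until all goroutines finish and returns the first non-nil error.
    if err := g.Wait(); err != nil {
        fmt.Printf("at least one task failed: %v\n", err)
    }
}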

Nested WaitGroups for Hierarchical Tasks

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

// Department represents a department with multiple teams
type Department struct {
    Name  string
    Teams []Team
}

// Team represents a team with multiple workers
type Team struct {
    Name    string
    Workers []string
}

// processDepartment processes all teams in a department
func processDepartment(dept Department, wg *sync.WaitGroup) {
    defer wg.Done()
    
    fmt.Printf("Department %s starting work\n", dept.Name)
    
    var teamWG sync.WaitGroup
    
    // Process all teams in parallel
    for _, team := range dept.Teams {
        teamWG.Add(1)
        go processTeam(team, &teamWG)
    }
    
    // Wait for all teams to complete
    teamWG.Wait()
    
    fmt.Printf("Department %s completed all work\n", dept.Name)
}

// processTeam processes all workers in a team
func processTeam(team Team, wg *sync.WaitGroup) {
    defer wg.Done()
    
    fmt.Printf("  Team %s starting work\n", team.Name)
    
    var workerWG sync.WaitGroup
    
    // Process all workers in parallel
    for _, worker := range team.Workers {
        workerWG.Add(1)
        go processWorker(worker, &workerWG)
    }
    
    // Wait for all workers to complete
    workerWG.Wait()
    
    fmt.Printf("  Team %s completed all work\n", team.Name)
}

// processWorker simulates worker processing
func processWorker(worker string, wg *sync.WaitGroup) {
    defer wg.Done()
    
    fmt.Printf("    Worker %s working...\n", worker)
    time.Sleep(time.Duration(100+rand.Intn(200)) * time.Millisecond)
    fmt.Printf("    Worker %s finished\n", worker)
}

func main() {
    departments := []Department{
        {
            Name: "Engineering",
            Teams: []Team{
                {
                    Name:    "Backend",
                    Workers: []string{"Alice", "Bob", "Charlie"},
                },
                {
                    Name:    "Frontend",
                    Workers: []string{"Diana", "Eve"},
                },
            },
        },
        {
            Name: "Marketing",
            Teams: []Team{
                {
                    Name:    "Digital",
                    Workers: []string{"Frank", "Grace"},
                },
                {
                    Name:    "Content",
                    Workers: []string{"Henry", "Ivy", "Jack"},
                },
            },
        },
    }
    
    var deptWG sync.WaitGroup
    
    fmt.Println("Starting company-wide project...")
    
    // Process all departments in parallel
    for _, dept := range departments {
        deptWG.Add(1)
        go processDepartment(dept, &deptWG)
    }
    
    // Wait for all departments to complete
    deptWG.Wait()
    
    fmt.Println("Company-wide project completed!")
}
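
Nesting WaitGroups is worth the extra goroutine layers when each level needs to report its own completion, as the department and team messages above do. If the only requirement is to wait for every leaf worker, a single WaitGroup incremented once per worker is simpler; a sketch of that flattened variant (the printout stands in for real work):

func runFlat(departments []Department) {
    var wg sync.WaitGroup
    for _, dept := range departments {
        for _, team := range dept.Teams {
            for _, worker := range team.Workers {
                wg.Add(1)
                go func(name string) {
                    defer wg.Done()
                    fmt.Printf("    Worker %s finished\n", name)
                }(worker)
            }
        }
    }
    wg.Wait() // one counter covers every worker in every team
}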

WaitGroup with Timeout

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// TimedTaskRunner runs tasks with timeout support
type TimedTaskRunner struct {
    timeout time.Duration
}

// NewTimedTaskRunner creates a new timed task runner
func NewTimedTaskRunner(timeout time.Duration) *TimedTaskRunner {
    return &TimedTaskRunner{timeout: timeout}
}

// RunWithTimeout runs tasks with a timeout
func (ttr *TimedTaskRunner) RunWithTimeout(tasks []func()) error {
    ctx, cancel := context.WithTimeout(context.Background(), ttr.timeout)
    defer cancel()
    
    var wg sync.WaitGroup
    done := make(chan struct{})
    
    // Start all tasks
    for i, task := range tasks {
        wg.Add(1)
        go func(taskID int, taskFunc func()) {
            defer wg.Done()
            fmt.Printf("Starting task %d\n", taskID)
            taskFunc()
            fmt.Printf("Completed task %d\n", taskID)
        }(i+1, task)
    }
    
    // Wait for completion in separate goroutine
    go func() {
        wg.Wait()
        close(done)
    }()
    
    // Wait for either completion or timeout
    select {
    case <-done:
        fmt.Println("All tasks completed successfully")
        return nil
    case <-ctx.Done():
        fmt.Println("Tasks timed out")
        return ctx.Err()
    }
}

// simulateTask creates a task that takes a specific duration
func simulateTask(duration time.Duration, name string) func() {
    return func() {
        fmt.Printf("  %s working for %v\n", name, duration)
        time.Sleep(duration)
        fmt.Printf("  %s finished\n", name)
    }
}

func main() {
    runner := NewTimedTaskRunner(2 * time.Second)
    
    // Test with tasks that complete within timeout
    fmt.Println("Test 1: Tasks that complete within timeout")
    tasks1 := []func(){
        simulateTask(300*time.Millisecond, "Quick task 1"),
        simulateTask(500*time.Millisecond, "Quick task 2"),
        simulateTask(400*time.Millisecond, "Quick task 3"),
    }
    
    if err := runner.RunWithTimeout(tasks1); err != nil {
        fmt.Printf("Error: %v\n", err)
    }
    
    fmt.Println("\nTest 2: Tasks that exceed timeout")
    tasks2 := []func(){
        simulateTask(800*time.Millisecond, "Slow task 1"),
        simulateTask(1500*time.Millisecond, "Slow task 2"),
        simulateTask(2500*time.Millisecond, "Very slow task"),
    }
    
    if err := runner.RunWithTimeout(tasks2); err != nil {
        fmt.Printf("Error: %v\n", err)
    }
}
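
Note that when the timeout fires, RunWithTimeout stops waiting but the slow goroutines keep running in the background until their sleeps finish. If tasks should stop early as well, they need to observe the context themselves; a sketch of a context-aware task (its signature differs from the plain func() used above):

// runCancellable does "work" for the given duration unless ctx is cancelled first.
func runCancellable(ctx context.Context, name string, duration time.Duration) {
    select {
    case <-time.After(duration):
        fmt.Printf("  %s finished\n", name)
    case <-ctx.Done():
        fmt.Printf("  %s cancelled: %v\n", name, ctx.Err())
    }
}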

Dynamic WaitGroup Management

package main

import (
    "fmt"
    "sync"
    "time"
)

// DynamicTaskManager manages tasks that can spawn other tasks
type DynamicTaskManager struct {
    taskChan chan func()
    quit     chan struct{}
    active   sync.WaitGroup // counts tasks that are queued or running
}

// NewDynamicTaskManager creates a new dynamic task manager
func NewDynamicTaskManager() *DynamicTaskManager {
    return &DynamicTaskManager{
        taskChan: make(chan func(), 100),
        quit:     make(chan struct{}),
    }
}

// Start begins processing tasks
func (dtm *DynamicTaskManager) Start() {
    go dtm.taskProcessor()
}

// taskProcessor processes tasks from the channel
func (dtm *DynamicTaskManager) taskProcessor() {
    for {
        select {
        case task := <-dtm.taskChan:
            go func() {
                defer dtm.active.Done() // paired with the Add in AddTask
                task()
            }()
        case <-dtm.quit:
            return
        }
    }
}

// AddTask adds a task to be processed. The counter is incremented here,
// at queue time, so Wait cannot return while a queued task is still
// waiting to be picked up by taskProcessor.
func (dtm *DynamicTaskManager) AddTask(task func()) {
    dtm.active.Add(1)
    select {
    case dtm.taskChan <- task:
    case <-dtm.quit:
        dtm.active.Done() // the task was never queued
    }
}

// Wait waits for all active tasks to complete
func (dtm *DynamicTaskManager) Wait() {
    dtm.active.Wait()
}

// Stop stops the task manager
func (dtm *DynamicTaskManager) Stop() {
    close(dtm.quit)
    dtm.Wait()
}

// recursiveTask demonstrates a task that spawns other tasks
func recursiveTask(manager *DynamicTaskManager, depth int, maxDepth int, id string) func() {
    return func() {
        fmt.Printf("Task %s (depth %d) starting\n", id, depth)
        time.Sleep(100 * time.Millisecond)
        
        if depth < maxDepth {
            // Spawn child tasks
            for i := 0; i < 2; i++ {
                childID := fmt.Sprintf("%s.%d", id, i+1)
                manager.AddTask(recursiveTask(manager, depth+1, maxDepth, childID))
            }
        }
        
        fmt.Printf("Task %s (depth %d) completed\n", id, depth)
    }
}

func main() {
    manager := NewDynamicTaskManager()
    manager.Start()
    defer manager.Stop()
    
    fmt.Println("Starting dynamic task processing...")
    
    // Add initial tasks that will spawn more tasks
    for i := 0; i < 3; i++ {
        taskID := fmt.Sprintf("root-%d", i+1)
        manager.AddTask(recursiveTask(manager, 0, 2, taskID))
    }
    
    // Wait for all tasks (including dynamically created ones) to complete
    manager.Wait()
    
    fmt.Println("All tasks completed!")
}
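
The counter is deliberately incremented in AddTask, at the moment a task is queued, rather than when taskProcessor picks it up. If the increment happened in taskProcessor instead, a parent task could queue a child and return before the processor counted it, letting Wait observe a zero counter while work was still pending in the channel.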

Best Practices

  1. Always Use defer: Call Done() with defer so it runs even if the goroutine panics
  2. Add Before Starting: Call Add() before starting goroutines to avoid race conditions
  3. Don’t Reuse WaitGroups: Create a new WaitGroup for each batch of operations
  4. Handle Panics: Recover inside goroutines so a single panicking task doesn’t crash the whole program (see the sketch after this list)
  5. Avoid Negative Counters: Don’t call Done() more times than Add()
  6. Use Timeouts: Combine with context for timeout handling
  7. Consider Alternatives: Use channels for complex coordination scenarios
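
A sketch of practice 4, assuming the work is wrapped in a small helper (safeGo is an illustrative name) so a panic in one goroutine is turned into a log line instead of crashing the program:

// safeGo runs fn in a goroutine tracked by wg, recovering from panics.
func safeGo(wg *sync.WaitGroup, fn func()) {
    wg.Add(1)
    go func() {
        defer wg.Done() // still runs after a recovered panic, so Wait returns
        defer func() {
            if r := recover(); r != nil {
                fmt.Printf("recovered from panic: %v\n", r)
            }
        }()
        fn()
    }()
}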

Common Pitfalls

1. Race Condition with Add/Done

// ❌ Bad: Race condition
func badExample() {
    var wg sync.WaitGroup
    
    for i := 0; i < 5; i++ {
        go func() {
            wg.Add(1) // Race: might be called after Wait()
            defer wg.Done()
            // do work
        }()
    }
    
    wg.Wait() // Might not wait for all goroutines
}

// ✅ Good: Add before starting goroutines
func goodExample() {
    var wg sync.WaitGroup
    
    for i := 0; i < 5; i++ {
        wg.Add(1) // Add before starting goroutine
        go func() {
            defer wg.Done()
            // do work
        }()
    }
    
    wg.Wait()
}

2. Forgetting to Call Done

// ❌ Bad: Missing Done() call
func badTask(wg *sync.WaitGroup) {
    // do work
    if someCondition {
        return // Forgot to call Done()!
    }
    wg.Done()
}

// ✅ Good: Always use defer
func goodTask(wg *sync.WaitGroup) {
    defer wg.Done() // Always called
    // do work
    if someCondition {
        return // Done() still called
    }
}

Testing WaitGroup Patterns

package main

import (
    "sync"
    "testing"
    "time"
)

func TestWaitGroupCompletion(t *testing.T) {
    var wg sync.WaitGroup
    completed := make([]bool, 5)
    
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(index int) {
            defer wg.Done()
            time.Sleep(10 * time.Millisecond)
            completed[index] = true
        }(i)
    }
    
    wg.Wait()
    
    // Verify all tasks completed
    for i, done := range completed {
        if !done {
            t.Errorf("Task %d did not complete", i)
        }
    }
}

func TestWaitGroupWithTimeout(t *testing.T) {
    var wg sync.WaitGroup
    done := make(chan struct{})
    
    wg.Add(1)
    go func() {
        defer wg.Done()
        time.Sleep(50 * time.Millisecond)
    }()
    
    go func() {
        wg.Wait()
        close(done)
    }()
    
    select {
    case <-done:
        // Success
    case <-time.After(100 * time.Millisecond):
        t.Error("WaitGroup did not complete within timeout")
    }
}

The WaitGroup pattern is essential for coordinating goroutines in Go. It provides a simple yet powerful way to wait for multiple concurrent operations to complete, making it perfect for parallel processing, batch operations, and synchronization barriers.


Next: Learn about Once Pattern for ensuring code executes exactly once.