# WaitGroup Pattern in Go
Go Concurrency Patterns Series: ← Mutex Patterns | Series Overview | Once Pattern →

## What is the WaitGroup Pattern?

The WaitGroup pattern uses `sync.WaitGroup` to coordinate the completion of multiple goroutines. A WaitGroup is a counter of outstanding goroutines: `Wait()` blocks until every registered goroutine has finished, which makes it a natural fit for implementing barriers and for waiting on parallel tasks to complete.

**Key Operations:**

- `Add(n)`: Increment the counter by n
- `Done()`: Decrement the counter by 1 (usually called with `defer`)
- `Wait()`: Block until the counter reaches zero

## Real-World Use Cases

- **Parallel Processing**: Wait for all workers to complete
- **Batch Operations**: Process multiple items concurrently
- **Service Initialization**: Wait for all services to start
- **Data Collection**: Gather results from multiple sources
- **Cleanup Operations**: Ensure all cleanup tasks finish
- **Testing**: Coordinate test goroutines

## Basic WaitGroup Usage

```go
package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

// Task represents work to be done
type Task struct {
	ID   int
	Name string
}

// processTask simulates processing a task
func processTask(task Task, wg *sync.WaitGroup) {
	defer wg.Done() // Always call Done when the goroutine finishes

	fmt.Printf("Starting task %d: %s\n", task.ID, task.Name)

	// Simulate work
	duration := time.Duration(rand.Intn(1000)) * time.Millisecond
	time.Sleep(duration)

	fmt.Printf("Completed task %d: %s (took %v)\n", task.ID, task.Name, duration)
}

func main() {
	tasks := []Task{
		{1, "Process images"},
		{2, "Send emails"},
		{3, "Update database"},
		{4, "Generate reports"},
		{5, "Backup files"},
	}

	var wg sync.WaitGroup

	fmt.Println("Starting parallel task processing...")

	// Start all tasks
	for _, task := range tasks {
		wg.Add(1) // Increment counter for each goroutine
		go processTask(task, &wg)
	}

	// Wait for all tasks to complete
	wg.Wait()
	fmt.Println("All tasks completed!")
}
```

## WaitGroup with Error Handling

```go
package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

// Result represents the outcome of a task
type Result struct {
	TaskID int
	Data   interface{}
	Error  error
}

// TaskProcessor handles tasks with error collection
type TaskProcessor struct {
	wg      sync.WaitGroup
	results chan Result
	errors  []error
	mu      sync.Mutex
}

// NewTaskProcessor creates a new task processor
func NewTaskProcessor(bufferSize int) *TaskProcessor {
	return &TaskProcessor{
		results: make(chan Result, bufferSize),
	}
}

// processTaskWithError simulates task processing that might fail
func (tp *TaskProcessor) processTaskWithError(taskID int, data interface{}) {
	defer tp.wg.Done()

	fmt.Printf("Processing task %d\n", taskID)

	// Simulate work
	time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)

	// Simulate random failures
	if rand.Float32() < 0.3 {
		err := fmt.Errorf("task %d failed", taskID)
		tp.results <- Result{TaskID: taskID, Error: err}

		// Collect error
		tp.mu.Lock()
		tp.errors = append(tp.errors, err)
		tp.mu.Unlock()

		fmt.Printf("Task %d failed\n", taskID)
		return
	}

	// Success
	result := fmt.Sprintf("Result from task %d", taskID)
	tp.results <- Result{TaskID: taskID, Data: result}
	fmt.Printf("Task %d completed successfully\n", taskID)
}

// ProcessTasks processes multiple tasks and collects results
func (tp *TaskProcessor) ProcessTasks(taskCount int) ([]Result, []error) {
	// Start all tasks
	for i := 1; i <= taskCount; i++ {
		tp.wg.Add(1)
		go tp.processTaskWithError(i, fmt.Sprintf("data-%d", i))
	}

	// Close results channel when all tasks complete
	go func() {
		tp.wg.Wait()
		close(tp.results)
	}()

	// Collect results
	var results []Result
	for result := range tp.results {
		results = append(results, result)
	}
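	// The loop above only exits once the results channel is closed, which happens
	// after wg.Wait() returns, so every worker has finished and any error it
	// recorded is already stored in tp.errors.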
	// Copy the collected errors under the mutex so the caller gets its own slice
	tp.mu.Lock()
	errors := make([]error, len(tp.errors))
	copy(errors, tp.errors)
	tp.mu.Unlock()

	return results, errors
}

func main() {
	processor := NewTaskProcessor(10)

	fmt.Println("Starting task processing with error handling...")

	results, errors := processor.ProcessTasks(8)

	fmt.Printf("\nProcessing complete!\n")
	fmt.Printf("Successful tasks: %d\n", len(results)-len(errors))
	fmt.Printf("Failed tasks: %d\n", len(errors))

	if len(errors) > 0 {
		fmt.Println("\nErrors:")
		for _, err := range errors {
			fmt.Printf("  - %v\n", err)
		}
	}

	fmt.Println("\nResults:")
	for _, result := range results {
		if result.Error == nil {
			fmt.Printf("  Task %d: %v\n", result.TaskID, result.Data)
		}
	}
}
```

## Nested WaitGroups for Hierarchical Tasks

```go
package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

// Department represents a department with multiple teams
type Department struct {
	Name  string
	Teams []Team
}

// Team represents a team with multiple workers
type Team struct {
	Name    string
	Workers []string
}

// processDepartment processes all teams in a department
func processDepartment(dept Department, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf("Department %s starting work\n", dept.Name)

	var teamWG sync.WaitGroup

	// Process all teams in parallel
	for _, team := range dept.Teams {
		teamWG.Add(1)
		go processTeam(team, &teamWG)
	}

	// Wait for all teams to complete
	teamWG.Wait()
	fmt.Printf("Department %s completed all work\n", dept.Name)
}

// processTeam processes all workers in a team
func processTeam(team Team, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf("  Team %s starting work\n", team.Name)

	var workerWG sync.WaitGroup

	// Process all workers in parallel
	for _, worker := range team.Workers {
		workerWG.Add(1)
		go processWorker(worker, &workerWG)
	}

	// Wait for all workers to complete
	workerWG.Wait()
	fmt.Printf("  Team %s completed all work\n", team.Name)
}

// processWorker simulates worker processing
func processWorker(worker string, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf("    Worker %s working...\n", worker)
	time.Sleep(time.Duration(100+rand.Intn(200)) * time.Millisecond)
	fmt.Printf("    Worker %s finished\n", worker)
}

func main() {
	departments := []Department{
		{
			Name: "Engineering",
			Teams: []Team{
				{
					Name:    "Backend",
					Workers: []string{"Alice", "Bob", "Charlie"},
				},
				{
					Name:    "Frontend",
					Workers: []string{"Diana", "Eve"},
				},
			},
		},
		{
			Name: "Marketing",
			Teams: []Team{
				{
					Name:    "Digital",
					Workers: []string{"Frank", "Grace"},
				},
				{
					Name:    "Content",
					Workers: []string{"Henry", "Ivy", "Jack"},
				},
			},
		},
	}

	var deptWG sync.WaitGroup

	fmt.Println("Starting company-wide project...")

	// Process all departments in parallel
	for _, dept := range departments {
		deptWG.Add(1)
		go processDepartment(dept, &deptWG)
	}

	// Wait for all departments to complete
	deptWG.Wait()
	fmt.Println("Company-wide project completed!")
}
```

## WaitGroup with Timeout

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// TimedTaskRunner runs tasks with timeout support
type TimedTaskRunner struct {
	timeout time.Duration
}

// NewTimedTaskRunner creates a new timed task runner
func NewTimedTaskRunner(timeout time.Duration) *TimedTaskRunner {
	return &TimedTaskRunner{timeout: timeout}
}

// RunWithTimeout runs tasks with a timeout
func (ttr *TimedTaskRunner) RunWithTimeout(tasks []func()) error {
	ctx, cancel := context.WithTimeout(context.Background(), ttr.timeout)
	defer cancel()

	var wg sync.WaitGroup
	done := make(chan struct{})

	// Start all tasks
	for i, task := range tasks {
		wg.Add(1)
		go func(taskID int, taskFunc func()) {
			defer wg.Done()
			fmt.Printf("Starting task %d\n", taskID)
			taskFunc()
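			// Note: taskFunc does not receive ctx, so if the timeout below fires,
			// RunWithTimeout returns but these goroutines keep running until the
			// task functions themselves finish.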
			fmt.Printf("Completed task %d\n", taskID)
		}(i+1, task)
	}

	// Wait for completion in a separate goroutine
	go func() {
		wg.Wait()
		close(done)
	}()

	// Wait for either completion or timeout
	select {
	case <-done:
		fmt.Println("All tasks completed successfully")
		return nil
	case <-ctx.Done():
		fmt.Println("Tasks timed out")
		return ctx.Err()
	}
}

// simulateTask creates a task that takes a specific duration
func simulateTask(duration time.Duration, name string) func() {
	return func() {
		fmt.Printf("  %s working for %v\n", name, duration)
		time.Sleep(duration)
		fmt.Printf("  %s finished\n", name)
	}
}

func main() {
	runner := NewTimedTaskRunner(2 * time.Second)

	// Test with tasks that complete within the timeout
	fmt.Println("Test 1: Tasks that complete within timeout")
	tasks1 := []func(){
		simulateTask(300*time.Millisecond, "Quick task 1"),
		simulateTask(500*time.Millisecond, "Quick task 2"),
		simulateTask(400*time.Millisecond, "Quick task 3"),
	}

	if err := runner.RunWithTimeout(tasks1); err != nil {
		fmt.Printf("Error: %v\n", err)
	}

	fmt.Println("\nTest 2: Tasks that exceed timeout")
	tasks2 := []func(){
		simulateTask(800*time.Millisecond, "Slow task 1"),
		simulateTask(1500*time.Millisecond, "Slow task 2"),
		simulateTask(2000*time.Millisecond, "Very slow task"),
	}

	if err := runner.RunWithTimeout(tasks2); err != nil {
		fmt.Printf("Error: %v\n", err)
	}
}
```

## Dynamic WaitGroup Management

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// DynamicTaskManager manages tasks that can spawn other tasks
type DynamicTaskManager struct {
	taskChan chan func()
	quit     chan struct{}
	active   sync.WaitGroup
}

// NewDynamicTaskManager creates a new dynamic task manager
func NewDynamicTaskManager() *DynamicTaskManager {
	return &DynamicTaskManager{
		taskChan: make(chan func(), 100),
		quit:     make(chan struct{}),
	}
}

// Start begins processing tasks
func (dtm *DynamicTaskManager) Start() {
	go dtm.taskProcessor()
}

// taskProcessor runs queued tasks; the active counter was already incremented
// by AddTask, so each task only needs to signal Done when it finishes
func (dtm *DynamicTaskManager) taskProcessor() {
	for {
		select {
		case task := <-dtm.taskChan:
			go func() {
				defer dtm.active.Done()
				task()
			}()
		case <-dtm.quit:
			return
		}
	}
}

// AddTask registers and queues a task. Incrementing the counter here, before
// the task is queued, lets a running task register its child tasks before it
// calls Done on its own entry, so Wait cannot return early
func (dtm *DynamicTaskManager) AddTask(task func()) {
	dtm.active.Add(1)
	select {
	case dtm.taskChan <- task:
	case <-dtm.quit:
		dtm.active.Done() // manager is stopping; the task was never queued
	}
}

// Wait waits for all active tasks to complete
func (dtm *DynamicTaskManager) Wait() {
	dtm.active.Wait()
}

// Stop stops the task manager
func (dtm *DynamicTaskManager) Stop() {
	close(dtm.quit)
	dtm.Wait()
}

// recursiveTask demonstrates a task that spawns other tasks
func recursiveTask(manager *DynamicTaskManager, depth int, maxDepth int, id string) func() {
	return func() {
		fmt.Printf("Task %s (depth %d) starting\n", id, depth)
		time.Sleep(100 * time.Millisecond)

		if depth < maxDepth {
			// Spawn child tasks
			for i := 0; i < 2; i++ {
				childID := fmt.Sprintf("%s.%d", id, i+1)
				manager.AddTask(recursiveTask(manager, depth+1, maxDepth, childID))
			}
		}

		fmt.Printf("Task %s (depth %d) completed\n", id, depth)
	}
}

func main() {
	manager := NewDynamicTaskManager()
	manager.Start()
	defer manager.Stop()

	fmt.Println("Starting dynamic task processing...")

	// Add initial tasks that will spawn more tasks
	for i := 0; i < 3; i++ {
		taskID := fmt.Sprintf("root-%d", i+1)
		manager.AddTask(recursiveTask(manager, 0, 2, taskID))
	}

	// Wait for all tasks (including dynamically created ones) to complete
	manager.Wait()
	fmt.Println("All tasks completed!")
}
```

## Best Practices

1. **Always Use defer**: Call `Done()` with `defer` so it runs even if the goroutine panics
2. **Add Before Starting**: Call `Add()` before starting goroutines to avoid race conditions
3. **Don't Reuse WaitGroups**: Create a new WaitGroup for each batch of operations
4. **Handle Panics**: Use `recover` in goroutines to prevent a panic from affecting the WaitGroup (see the sketch after this list)
5. **Avoid Negative Counters**: Don't call `Done()` more times than `Add()`
6. **Use Timeouts**: Combine with `context` for timeout handling
7. **Consider Alternatives**: Use channels for complex coordination scenarios
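To make the "Handle Panics" point concrete, here is a minimal sketch of a panic-safe worker wrapper. It uses only the standard library; the `safeWorker` name and the simulated failure are illustrative choices, not part of any established API.

```go
package main

import (
	"fmt"
	"sync"
)

// safeWorker wraps the real work so a panic inside it cannot crash the program
// or leave the WaitGroup counter stuck: Done is deferred, and recover runs
// before the goroutine exits.
func safeWorker(id int, wg *sync.WaitGroup, work func()) {
	defer wg.Done() // counter is released even if work panics
	defer func() {
		if r := recover(); r != nil {
			fmt.Printf("worker %d recovered from panic: %v\n", id, r)
		}
	}()
	work()
}

func main() {
	var wg sync.WaitGroup

	for i := 1; i <= 3; i++ {
		wg.Add(1)
		id := i
		go safeWorker(id, &wg, func() {
			if id == 2 {
				panic("simulated failure") // one worker misbehaves
			}
			fmt.Printf("worker %d finished normally\n", id)
		})
	}

	wg.Wait() // still returns: the panicking goroutine called Done via defer
	fmt.Println("all workers accounted for")
}
```

Because deferred calls run even while a panic unwinds, `Done()` executes either way; the `recover` is what keeps the panicking goroutine from crashing the whole program and lets `Wait()` return normally.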
## Common Pitfalls

### 1. Race Condition with Add/Done

```go
// Bad: Race condition
func badExample() {
	var wg sync.WaitGroup

	for i := 0; i < 5; i++ {
		go func() {
			wg.Add(1) // Race: might be called after Wait()
			defer wg.Done()
			// do work
		}()
	}

	wg.Wait() // Might not wait for all goroutines
}

// Good: Add before starting goroutines
func goodExample() {
	var wg sync.WaitGroup

	for i := 0; i < 5; i++ {
		wg.Add(1) // Add before starting the goroutine
		go func() {
			defer wg.Done()
			// do work
		}()
	}

	wg.Wait()
}
```

### 2. Forgetting to Call Done

```go
// Bad: Missing Done() call
func badTask(wg *sync.WaitGroup) {
	// do work
	if someCondition {
		return // Forgot to call Done()!
	}
	wg.Done()
}

// Good: Always use defer
func goodTask(wg *sync.WaitGroup) {
	defer wg.Done() // Always called

	// do work
	if someCondition {
		return // Done() still called
	}
}
```

## Testing WaitGroup Patterns

```go
package main

import (
	"sync"
	"testing"
	"time"
)

func TestWaitGroupCompletion(t *testing.T) {
	var wg sync.WaitGroup
	completed := make([]bool, 5)

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(index int) {
			defer wg.Done()
			time.Sleep(10 * time.Millisecond)
			completed[index] = true
		}(i)
	}

	wg.Wait()

	// Verify all tasks completed
	for i, done := range completed {
		if !done {
			t.Errorf("Task %d did not complete", i)
		}
	}
}

func TestWaitGroupWithTimeout(t *testing.T) {
	var wg sync.WaitGroup
	done := make(chan struct{})

	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(50 * time.Millisecond)
	}()

	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// Success
	case <-time.After(100 * time.Millisecond):
		t.Error("WaitGroup did not complete within timeout")
	}
}
```

The WaitGroup pattern is essential for coordinating goroutines in Go. It provides a simple yet powerful way to wait for multiple concurrent operations to complete, making it perfect for parallel processing, batch operations, and synchronization barriers. ...