Go Concurrency Patterns Series: ← Mutex Patterns | Series Overview | Once Pattern →
What is the WaitGroup Pattern?
The WaitGroup pattern uses sync.WaitGroup to coordinate the completion of multiple goroutines.
A WaitGroup maintains a counter of outstanding goroutines, and Wait() blocks until that counter drops to zero, i.e. until all registered goroutines have finished executing. This makes it perfect for implementing barriers and waiting for parallel tasks to complete.
Key Operations:
- Add(n): Increment the counter by n
- Done(): Decrement the counter by 1 (usually called with defer)
- Wait(): Block until counter reaches zero
Real-World Use Cases
- Parallel Processing: Wait for all workers to complete
- Batch Operations: Process multiple items concurrently
- Service Initialization: Wait for all services to start
- Data Collection: Gather results from multiple sources
- Cleanup Operations: Ensure all cleanup tasks finish
- Testing: Coordinate test goroutines
Basic WaitGroup Usage
package main
import (
"fmt"
"math/rand"
"sync"
"time"
)
// Task represents work to be done.
// main builds a slice of these and passes each one by value to processTask.
type Task struct {
	// ID is the numeric identifier printed in the progress messages.
	ID int
	// Name is the human-readable label printed next to the ID.
	Name string
}
// processTask pretends to execute task: it logs the start, sleeps for a
// random interval below one second, then logs completion together with the
// interval it slept. wg is signalled exactly once, when the function returns.
func processTask(task Task, wg *sync.WaitGroup) {
	// Deferred first so the WaitGroup is released on every exit path.
	defer wg.Done()

	fmt.Printf("Starting task %d: %s\n", task.ID, task.Name)

	// Stand-in for real work: pause for 0-999 ms.
	pause := time.Millisecond * time.Duration(rand.Intn(1000))
	time.Sleep(pause)

	fmt.Printf("Completed task %d: %s (took %v)\n", task.ID, task.Name, pause)
}
// main runs five sample tasks concurrently and blocks until every one of
// them has reported completion through the shared WaitGroup.
func main() {
	tasks := []Task{
		{ID: 1, Name: "Process images"},
		{ID: 2, Name: "Send emails"},
		{ID: 3, Name: "Update database"},
		{ID: 4, Name: "Generate reports"},
		{ID: 5, Name: "Backup files"},
	}

	var wg sync.WaitGroup
	fmt.Println("Starting parallel task processing...")

	// Register all goroutines before any of them is started, so the counter
	// cannot reach zero while work is still being launched.
	wg.Add(len(tasks))
	for i := range tasks {
		go processTask(tasks[i], &wg)
	}

	// Block until each goroutine has called Done.
	wg.Wait()
	fmt.Println("All tasks completed!")
}
WaitGroup with Error Handling
package main
import (
"fmt"
"math/rand"
"sync"
"time"
)
// Result represents the outcome of a task.
// Workers send exactly one Result per task: failures carry Error, successes
// carry Data.
type Result struct {
	// TaskID is the identifier of the task that produced this result.
	TaskID int
	// Data is the payload of a successful task; the worker in this listing
	// stores a formatted string here and leaves it nil on failure.
	Data interface{}
	// Error is non-nil when the task failed.
	Error error
}
// TaskProcessor handles tasks with error collection.
// Create it with NewTaskProcessor so that the results channel exists.
type TaskProcessor struct {
	// wg counts the worker goroutines started by ProcessTasks.
	wg sync.WaitGroup
	// results carries one Result per task from the workers to ProcessTasks;
	// it is closed once every worker has finished.
	results chan Result
	// errors accumulates the error of every failed task; guarded by mu.
	errors []error
	// mu protects errors.
	mu sync.Mutex
}
// NewTaskProcessor returns a TaskProcessor whose results channel can buffer
// up to bufferSize entries. All other fields start at their zero values.
func NewTaskProcessor(bufferSize int) *TaskProcessor {
	tp := new(TaskProcessor)
	tp.results = make(chan Result, bufferSize)
	return tp
}
// processTaskWithError is the worker body for a single task. It sleeps for
// up to half a second, then either succeeds or, about 30% of the time,
// fails. Either way exactly one Result is sent on tp.results; failures are
// additionally appended to tp.errors. The data argument is accepted but not
// used by this simulation.
func (tp *TaskProcessor) processTaskWithError(taskID int, data interface{}) {
	// Release the WaitGroup no matter which branch returns.
	defer tp.wg.Done()

	fmt.Printf("Processing task %d\n", taskID)

	// Stand-in for real work: 0-499 ms.
	time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)

	// Draw once to decide the outcome of this run.
	failed := rand.Float32() < 0.3
	if !failed {
		payload := fmt.Sprintf("Result from task %d", taskID)
		tp.results <- Result{TaskID: taskID, Data: payload}
		fmt.Printf("Task %d completed successfully\n", taskID)
		return
	}

	err := fmt.Errorf("task %d failed", taskID)
	tp.results <- Result{TaskID: taskID, Error: err}

	// The error list is shared between workers, so append under the lock.
	tp.mu.Lock()
	tp.errors = append(tp.errors, err)
	tp.mu.Unlock()

	fmt.Printf("Task %d failed\n", taskID)
}
// ProcessTasks starts taskCount workers (IDs 1..taskCount), gathers every
// Result they send, and returns those results together with a copy of the
// errors recorded by failed workers. Results are returned in arrival order
// and include both successes and failures.
//
// The results channel is closed before this method returns, so a
// TaskProcessor can serve only one ProcessTasks call.
func (tp *TaskProcessor) ProcessTasks(taskCount int) ([]Result, []error) {
	for id := 1; id <= taskCount; id++ {
		tp.wg.Add(1)
		go tp.processTaskWithError(id, fmt.Sprintf("data-%d", id))
	}

	// Closing the channel after the last worker is what ends the receive
	// loop below.
	closeWhenDone := func() {
		tp.wg.Wait()
		close(tp.results)
	}
	go closeWhenDone()

	var collected []Result
	for {
		r, open := <-tp.results
		if !open {
			break
		}
		collected = append(collected, r)
	}

	// Hand back a snapshot so callers never alias the internal slice.
	tp.mu.Lock()
	failures := make([]error, len(tp.errors))
	copy(failures, tp.errors)
	tp.mu.Unlock()

	return collected, failures
}
// main processes eight simulated tasks and prints a summary: the number of
// successes and failures, each error, and the payload of each success.
func main() {
	processor := NewTaskProcessor(10)
	fmt.Println("Starting task processing with error handling...")

	results, errs := processor.ProcessTasks(8)

	// results holds one entry per task, so successes are the remainder once
	// the failures are subtracted.
	failed := len(errs)
	succeeded := len(results) - failed

	fmt.Printf("\nProcessing complete!\n")
	fmt.Printf("Successful tasks: %d\n", succeeded)
	fmt.Printf("Failed tasks: %d\n", failed)

	if failed > 0 {
		fmt.Println("\nErrors:")
		for i := range errs {
			fmt.Printf(" - %v\n", errs[i])
		}
	}

	fmt.Println("\nResults:")
	for _, r := range results {
		if r.Error != nil {
			continue
		}
		fmt.Printf(" Task %d: %v\n", r.TaskID, r.Data)
	}
}
Nested WaitGroups for Hierarchical Tasks
package main
import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)
// Department represents a department with multiple teams.
type Department struct {
	// Name is printed in the department's start and completion messages.
	Name string
	// Teams lists the teams that processDepartment runs in parallel.
	Teams []Team
}
// Team represents a team with multiple workers.
type Team struct {
	// Name is printed in the team's start and completion messages.
	Name string
	// Workers holds the worker names; processTeam starts one goroutine per
	// entry.
	Workers []string
}
// processDepartment runs every team of dept concurrently and returns once
// all of them have finished. wg belongs to the caller and is signalled when
// this function returns; the teams are tracked by a separate, local
// WaitGroup.
func processDepartment(dept Department, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf("Department %s starting work\n", dept.Name)

	// Inner barrier scoped to this department's teams.
	var teams sync.WaitGroup
	teams.Add(len(dept.Teams))
	for i := range dept.Teams {
		go processTeam(dept.Teams[i], &teams)
	}
	teams.Wait()

	fmt.Printf("Department %s completed all work\n", dept.Name)
}
// processTeam runs every worker of team concurrently and returns once all
// of them have finished. wg belongs to the caller and is signalled when this
// function returns; the workers are tracked by a separate, local WaitGroup.
func processTeam(team Team, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf(" Team %s starting work\n", team.Name)

	// Inner barrier scoped to this team's workers.
	var members sync.WaitGroup
	members.Add(len(team.Workers))
	for i := range team.Workers {
		go processWorker(team.Workers[i], &members)
	}
	members.Wait()

	fmt.Printf(" Team %s completed all work\n", team.Name)
}
// processWorker simulates one worker: it logs the start, sleeps for a random
// 100-299 ms, logs the finish, and signals wg on return.
func processWorker(worker string, wg *sync.WaitGroup) {
	defer wg.Done()

	fmt.Printf(" Worker %s working...\n", worker)

	busy := time.Duration(100+rand.Intn(200)) * time.Millisecond
	time.Sleep(busy)

	fmt.Printf(" Worker %s finished\n", worker)
}
// main builds a two-department sample organisation and processes all
// departments concurrently, waiting for the whole hierarchy to finish.
func main() {
	engineering := Department{
		Name: "Engineering",
		Teams: []Team{
			{Name: "Backend", Workers: []string{"Alice", "Bob", "Charlie"}},
			{Name: "Frontend", Workers: []string{"Diana", "Eve"}},
		},
	}
	marketing := Department{
		Name: "Marketing",
		Teams: []Team{
			{Name: "Digital", Workers: []string{"Frank", "Grace"}},
			{Name: "Content", Workers: []string{"Henry", "Ivy", "Jack"}},
		},
	}
	departments := []Department{engineering, marketing}

	var deptWG sync.WaitGroup
	fmt.Println("Starting company-wide project...")

	// Outermost barrier: one count per department, registered before any
	// department goroutine starts.
	deptWG.Add(len(departments))
	for i := range departments {
		go processDepartment(departments[i], &deptWG)
	}
	deptWG.Wait()

	fmt.Println("Company-wide project completed!")
}
WaitGroup with Timeout
package main
import (
"context"
"fmt"
"sync"
"time"
)
// TimedTaskRunner runs tasks with timeout support.
type TimedTaskRunner struct {
	// timeout is the deadline RunWithTimeout applies to a whole batch of
	// tasks, measured from the moment RunWithTimeout is called.
	timeout time.Duration
}
// NewTimedTaskRunner returns a runner that gives each batch passed to
// RunWithTimeout at most timeout to finish.
func NewTimedTaskRunner(timeout time.Duration) *TimedTaskRunner {
	runner := new(TimedTaskRunner)
	runner.timeout = timeout
	return runner
}
// RunWithTimeout starts every task on its own goroutine and waits until
// either all of them have returned or ttr.timeout has elapsed, whichever
// happens first. It returns nil on completion and the context's error on
// timeout.
//
// Tasks take no context, so a timeout only stops the waiting: goroutines
// that are still running continue in the background after this method
// returns.
func (ttr *TimedTaskRunner) RunWithTimeout(tasks []func()) error {
	ctx, cancel := context.WithTimeout(context.Background(), ttr.timeout)
	defer cancel()

	var pending sync.WaitGroup
	pending.Add(len(tasks))
	for idx := range tasks {
		// Per-iteration copies; task numbers in the log are 1-based.
		number, run := idx+1, tasks[idx]
		go func() {
			defer pending.Done()
			fmt.Printf("Starting task %d\n", number)
			run()
			fmt.Printf("Completed task %d\n", number)
		}()
	}

	// WaitGroup.Wait cannot be used in a select, so a helper goroutine turns
	// "all done" into a channel close.
	finished := make(chan struct{})
	go func() {
		pending.Wait()
		close(finished)
	}()

	select {
	case <-ctx.Done():
		fmt.Println("Tasks timed out")
		return ctx.Err()
	case <-finished:
		fmt.Println("All tasks completed successfully")
		return nil
	}
}
// simulateTask builds a task that, when invoked, announces itself, sleeps
// for duration, and then reports that it has finished. Nothing is printed
// or slept until the returned function is called.
func simulateTask(duration time.Duration, name string) func() {
	task := func() {
		fmt.Printf(" %s working for %v\n", name, duration)
		time.Sleep(duration)
		fmt.Printf(" %s finished\n", name)
	}
	return task
}
// main exercises a two-second runner with two batches: one whose tasks all
// finish well inside the deadline, and one containing a task as long as the
// deadline itself.
func main() {
	runner := NewTimedTaskRunner(2 * time.Second)

	scenarios := []struct {
		banner string
		tasks  []func()
	}{
		{
			// Every task finishes well inside the timeout.
			banner: "Test 1: Tasks that complete within timeout",
			tasks: []func(){
				simulateTask(300*time.Millisecond, "Quick task 1"),
				simulateTask(500*time.Millisecond, "Quick task 2"),
				simulateTask(400*time.Millisecond, "Quick task 3"),
			},
		},
		{
			banner: "\nTest 2: Tasks that exceed timeout",
			tasks: []func(){
				simulateTask(800*time.Millisecond, "Slow task 1"),
				simulateTask(1500*time.Millisecond, "Slow task 2"),
				simulateTask(2000*time.Millisecond, "Very slow task"),
			},
		},
	}

	for _, sc := range scenarios {
		fmt.Println(sc.banner)
		if err := runner.RunWithTimeout(sc.tasks); err != nil {
			fmt.Printf("Error: %v\n", err)
		}
	}
}
Dynamic WaitGroup Management
package main
import (
"fmt"
"sync"
"time"
)
// DynamicTaskManager manages tasks that can spawn other tasks.
// Create it with NewDynamicTaskManager so that both channels exist.
type DynamicTaskManager struct {
	// wg is not referenced by any method shown in this listing.
	wg sync.WaitGroup
	// taskChan queues submitted tasks until the dispatch loop picks them up.
	taskChan chan func()
	// quit is closed by Stop to end the dispatch loop and unblock AddTask.
	quit chan struct{}
	// active is the WaitGroup that Wait blocks on; each task goroutine
	// signals it when the task returns.
	active sync.WaitGroup
}
// NewDynamicTaskManager returns a manager with a task queue buffered for
// 100 entries and an open quit channel. The dispatch loop is not running
// until Start is called.
func NewDynamicTaskManager() *DynamicTaskManager {
	manager := new(DynamicTaskManager)
	manager.taskChan = make(chan func(), 100)
	manager.quit = make(chan struct{})
	return manager
}
// Start begins processing tasks by launching the taskProcessor loop on its
// own goroutine; it returns immediately. Nothing here guards against a
// second call, which would start a second loop.
func (dtm *DynamicTaskManager) Start() {
	go dtm.taskProcessor()
}
// taskProcessor is the dispatch loop started by Start. It takes tasks off
// taskChan and runs each one on its own goroutine until quit is closed.
//
// Tasks are counted in dtm.active by AddTask at submission time, not here.
// Counting at dequeue time would let Wait observe a zero counter while
// tasks are still sitting in the queue and return before they ever ran.
func (dtm *DynamicTaskManager) taskProcessor() {
	for {
		select {
		case task := <-dtm.taskChan:
			go func() {
				// Balances the Add performed by AddTask.
				defer dtm.active.Done()
				task()
			}()
		case <-dtm.quit:
			// Tasks still queued will never run; release their counts so
			// that Wait (and therefore Stop) cannot block forever.
			dtm.discardQueued()
			return
		}
	}
}

// discardQueued removes every task currently buffered in taskChan without
// blocking and releases the dtm.active count each of them holds. It is only
// used once quit has been closed.
func (dtm *DynamicTaskManager) discardQueued() {
	for {
		select {
		case <-dtm.taskChan:
			dtm.active.Done()
		default:
			return
		}
	}
}

// AddTask adds a task to be processed. The task is registered with
// dtm.active before it is queued, so a Wait that starts after AddTask
// returns is guaranteed to cover it. Once the manager has been stopped the
// call is a no-op and the task is dropped.
//
// Call AddTask either before Wait or from inside a running task; this is
// the same ordering sync.WaitGroup requires between Add and Wait.
func (dtm *DynamicTaskManager) AddTask(task func()) {
	// Fast path: already stopped, nothing to register.
	select {
	case <-dtm.quit:
		return
	default:
	}

	dtm.active.Add(1)

	select {
	case dtm.taskChan <- task:
	case <-dtm.quit:
		// Stopped while waiting for queue space: undo the registration.
		dtm.active.Done()
		return
	}

	// The send can succeed after quit was closed and the dispatch loop has
	// already drained the queue. Sweep again so no counted task is stranded.
	select {
	case <-dtm.quit:
		dtm.discardQueued()
	default:
	}
}
// Wait waits for all active tasks to complete: it blocks until the
// dtm.active WaitGroup counter is zero.
func (dtm *DynamicTaskManager) Wait() {
	dtm.active.Wait()
}
// Stop shuts the manager down: it closes quit, which ends the dispatch loop
// and unblocks any pending AddTask, and then blocks until the active-task
// counter reaches zero. Because quit is closed here, Stop must be called at
// most once.
func (dtm *DynamicTaskManager) Stop() {
	close(dtm.quit)
	dtm.active.Wait()
}
// recursiveTask builds a task that logs its start, sleeps 100 ms, and, while
// depth is below maxDepth, submits two child tasks to manager before logging
// its own completion. Children get the IDs "<id>.1" and "<id>.2" and run one
// level deeper.
func recursiveTask(manager *DynamicTaskManager, depth int, maxDepth int, id string) func() {
	return func() {
		fmt.Printf("Task %s (depth %d) starting\n", id, depth)
		time.Sleep(100 * time.Millisecond)

		if depth < maxDepth {
			for n := 1; n <= 2; n++ {
				child := recursiveTask(manager, depth+1, maxDepth, fmt.Sprintf("%s.%d", id, n))
				manager.AddTask(child)
			}
		}

		fmt.Printf("Task %s (depth %d) completed\n", id, depth)
	}
}
// main starts a manager, submits three root tasks that each spawn children
// down to depth 2, waits for the manager to report that it is idle, and
// stops the manager on exit.
func main() {
	manager := NewDynamicTaskManager()
	manager.Start()
	defer manager.Stop()

	fmt.Println("Starting dynamic task processing...")

	// Root tasks are named root-1 .. root-3.
	for n := 1; n <= 3; n++ {
		root := recursiveTask(manager, 0, 2, fmt.Sprintf("root-%d", n))
		manager.AddTask(root)
	}

	manager.Wait()
	fmt.Println("All tasks completed!")
}
Best Practices
- Always Use defer: Call Done() with defer to ensure it’s called even if a panic occurs
- Add Before Starting: Call Add() before starting goroutines to avoid race conditions
- Don’t Reuse WaitGroups: Create a new WaitGroup for each batch of operations
- Handle Panics: Use recover in goroutines to prevent panic from affecting WaitGroup
- Avoid Negative Counters: Don’t call Done() more times than Add() accounted for; a negative counter panics
- Use Timeouts: Combine with context for timeout handling
- Consider Alternatives: Use channels for complex coordination scenarios
Common Pitfalls
1. Race Condition with Add/Done
// ❌ Bad: Race condition
// badExample registers with the WaitGroup from inside each goroutine.
// Nothing orders those Add calls before the Wait at the bottom, so Wait can
// see a zero counter and return while goroutines are still starting.
func badExample() {
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		go func() {
			wg.Add(1) // Race: might be called after Wait()
			defer wg.Done()
			// do work
		}()
	}
	wg.Wait() // Might not wait for all goroutines
}
// ✅ Good: Add before starting goroutines
// goodExample starts five goroutines and waits for all of them. Each Add
// runs on the launching goroutine before the corresponding go statement, so
// every registration is in place by the time Wait is reached.
func goodExample() {
	const workers = 5

	var wg sync.WaitGroup
	work := func() {
		defer wg.Done()
		// do work
	}

	for n := 0; n < workers; n++ {
		wg.Add(1) // Registered here, not inside the goroutine
		go work()
	}
	wg.Wait()
}
2. Forgetting to Call Done
// ❌ Bad: Missing Done() call
// badTask reaches wg.Done() only on the fall-through path; the early return
// skips it, so a caller blocked in Wait would never be released.
// NOTE(review): someCondition is not declared in this snippet — presumably a
// placeholder for real logic.
func badTask(wg *sync.WaitGroup) {
	// do work
	if someCondition {
		return // Forgot to call Done()!
	}
	wg.Done()
}
// ✅ Good: Always use defer
// goodTask defers wg.Done() as its first statement, so the counter is
// decremented on every return path, including the early return below.
// NOTE(review): someCondition is not declared in this snippet — presumably a
// placeholder for real logic.
func goodTask(wg *sync.WaitGroup) {
	defer wg.Done() // Always called
	// do work
	if someCondition {
		return // Done() still called
	}
}
Testing WaitGroup Patterns
package main
import (
"sync"
"testing"
"time"
)
// TestWaitGroupCompletion checks that Wait does not return until every
// goroutine has run: five goroutines each set their own slot of a shared
// slice, and after Wait all slots must be true.
func TestWaitGroupCompletion(t *testing.T) {
	const workers = 5

	var wg sync.WaitGroup
	finished := make([]bool, workers)

	wg.Add(workers)
	for slot := range finished {
		go func(slot int) {
			defer wg.Done()
			time.Sleep(10 * time.Millisecond)
			// Each goroutine writes a distinct element, so no lock is needed.
			finished[slot] = true
		}(slot)
	}
	wg.Wait()

	for slot, ok := range finished {
		if ok {
			continue
		}
		t.Errorf("Task %d did not complete", slot)
	}
}
// TestWaitGroupWithTimeout checks that a WaitGroup guarding a 50 ms task is
// released within 100 ms. Wait is bridged to a channel so that it can take
// part in a select alongside the timer.
func TestWaitGroupWithTimeout(t *testing.T) {
	var wg sync.WaitGroup

	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(50 * time.Millisecond)
	}()

	released := make(chan struct{})
	go func() {
		wg.Wait()
		close(released)
	}()

	deadline := time.After(100 * time.Millisecond)
	select {
	case <-deadline:
		t.Error("WaitGroup did not complete within timeout")
	case <-released:
		// Finished in time.
	}
}
The WaitGroup pattern is essential for coordinating goroutines in Go. It provides a simple yet powerful way to wait for multiple concurrent operations to complete, making it perfect for parallel processing, batch operations, and synchronization barriers.
Next: Learn about Once Pattern for ensuring code executes exactly once.