Go 语言 sync.WaitGroup 深度解析

发布于:2025-06-09 ⋅ 阅读:(21) ⋅ 点赞:(0)

Go 语言 sync.WaitGroup 深度解析

sync.WaitGroup 是 Go 语言标准库中用于协调多个 goroutine 执行的重要同步原语,它提供了一种简单有效的方式让主 goroutine 等待一组工作 goroutine 完成任务。

核心概念

WaitGroup 结构

type WaitGroup struct {
    // 包含状态和信号量
    // ...
}

主要方法

func (wg *WaitGroup) Add(delta int)  // 增加等待的计数器值
func (wg *WaitGroup) Done()          // 减少计数器值(等同于 Add(-1))
func (wg *WaitGroup) Wait()          // 阻塞直到计数器归零

基本使用模式

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    
    // 启动3个goroutine
    for i := 1; i <= 3; i++ {
        wg.Add(1) // 增加一个等待计数
        go worker(i, &wg)
    }
    
    // 等待所有goroutine完成
    wg.Wait()
    fmt.Println("所有工作已完成!")
}

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done() // 确保任务结束时减少计数
    
    fmt.Printf("Worker %d 开始工作\n", id)
    time.Sleep(time.Duration(id) * time.Second) // 模拟工作耗时
    fmt.Printf("Worker %d 完成工作\n", id)
}

​输出​​:

Worker 1 开始工作
Worker 3 开始工作
Worker 2 开始工作
Worker 1 完成工作
Worker 2 完成工作
Worker 3 完成工作
所有工作已完成!

关键特性解析

1. Add 和 Done 必须配对

// 错误示例:Add 和 Done 数量不匹配
wg.Add(3)
for i := 0; i < 2; i++ {
    go func() {
        defer wg.Done()
        // ...
    }()
}
wg.Wait() // 死锁:永远等待第三个 Done

2. 不能在 Wait 之后调用 Add

var wg sync.WaitGroup

go func() {
    wg.Add(1)
    defer wg.Done()
    // ...
}()

wg.Wait()

// 错误:在 Wait 之后添加任务
wg.Add(1) // 可能导致 panic

3. 零值可用

// 无需初始化
var wg sync.WaitGroup

高级用法

1. 任务组嵌套

func main() {
    var mainWg sync.WaitGroup
    
    for groupID := 1; groupID <= 3; groupID++ {
        mainWg.Add(1)
        go func(id int) {
            defer mainWg.Done()
            runTaskGroup(id)
        }(groupID)
    }
    
    mainWg.Wait()
    fmt.Println("所有任务组完成")
}

func runTaskGroup(groupID int) {
    var groupWg sync.WaitGroup
    taskCount := groupID * 2
    
    for taskID := 1; taskID <= taskCount; taskID++ {
        groupWg.Add(1)
        go func(tid int) {
            defer groupWg.Done()
            time.Sleep(time.Duration(tid) * 500 * time.Millisecond)
            fmt.Printf("任务组 %d - 任务 %d 完成\n", groupID, tid)
        }(taskID)
    }
    
    groupWg.Wait()
    fmt.Printf("任务组 %d 完成\n", groupID)
}

2. 并发限制 + WaitGroup

func runConcurrentTasks(maxConcurrent int, tasks []func()) {
    sem := make(chan struct{}, maxConcurrent)
    var wg sync.WaitGroup
    
    for i, task := range tasks {
        sem <- struct{}{} // 获取信号量
        wg.Add(1)
        
        go func(idx int, t func()) {
            defer func() {
                <-sem      // 释放信号量
                wg.Done()  // 任务完成
            }()
            
            fmt.Printf("开始任务 %d\n", idx)
            t()
            fmt.Printf("完成任务 %d\n", idx)
        }(i, task)
    }
    
    wg.Wait()
    close(sem)
}

3. 等待超时控制

func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
    ch := make(chan struct{})
    
    go func() {
        defer close(ch)
        wg.Wait()
        ch <- struct{}{}
    }()
    
    select {
    case <-ch:
        return true // 正常完成
    case <-time.After(timeout):
        return false // 超时
    }
}

// 使用方式
var wg sync.WaitGroup
// ... 添加任务
if !waitWithTimeout(&wg, 5*time.Second) {
    fmt.Println("任务执行超时")
}

WaitGroup 内部实现原理

底层结构(简化版)

type WaitGroup struct {
    state1 [3]uint32 // 包含计数器、等待计数器和信号量的状态
}

工作流程

  1. ​Add(delta int)​​:

    • 原子操作增加计数器
    • 如果计数器变为负值,触发 panic
  2. ​Done()​​:

    • 调用 Add(-1)
    • 如果计数器归零,唤醒所有等待的 goroutine
  3. ​Wait()​​:

    • 等待直到计数器归零
    • 使用信号量避免忙等待
    • 多个 goroutine 可以同时调用 Wait

常见陷阱与解决方案

1. 指针传递问题

// 错误:传递 WaitGroup 副本
func worker(wg sync.WaitGroup) {
    defer wg.Done()
    // ...
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go worker(wg) // 传递副本,无效
    wg.Wait()     // 永远等待
}

// 正确:传递指针
func worker(wg *sync.WaitGroup) {
    defer wg.Done()
    // ...
}

2. 忘记调用 Done

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    
    go func() {
        // 忘记调用 Done!
        time.Sleep(time.Second)
        fmt.Println("完成任务")
    }()
    
    wg.Wait() // 死锁
}

​解决方案​​: 始终使用 defer wg.Done()

3. 提前调用 Wait

func main() {
    var wg sync.WaitGroup
    
    go func() {
        // 这里可能比 Add 调用晚执行
        wg.Add(1)
        defer wg.Done()
        // ...
    }()
    
    wg.Wait() // 可能过早退出
}

​解决方案​​: 在主 goroutine 中集中添加任务

func main() {
    var wg sync.WaitGroup
    tasks := []func(){...}
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t func()) {
            defer wg.Done()
            t()
        }(task)
    }
    
    wg.Wait()
}

WaitGroup 与其他同步原语对比

原语 用途 适用场景
WaitGroup 等待一组 goroutine 完成 批量任务、并行计算
Channel goroutine 间通信 数据传递、事件通知
Mutex 保护共享资源 临界区访问
Cond goroutine 条件等待 等待特定条件满足
Once 确保操作只执行一次 单例初始化
RWMutex 读写分离的互斥锁 读多写少的场景

最佳实践

  1. ​使用 defer 调用 Done​​:

    wg.Add(1)
    go func() {
        defer wg.Done() // 确保一定调用
        // ...
    }()
  2. ​主协程中初始化计数​​:

    tasks := getTasks()
    wg.Add(len(tasks)) // 一次性添加计数
    for _, task := range tasks {
        go process(task, &wg)
    }
    wg.Wait()
  3. ​避免在子协程中调用 Add​​:

    // 不推荐
    go func() {
        wg.Add(1)
        defer wg.Done()
        // ...
    }()
    
    // 推荐
    wg.Add(1)
    go func() {
        defer wg.Done()
        // ...
    }()
  4. ​使用 Context 处理超时​​:

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    
    go func() {
        defer wg.Done()
        
        select {
        case <-time.After(10 * time.Second):
            // 长时间任务
        case <-ctx.Done():
            // 超时取消
            return
        }
    }()

性能考虑

sync.WaitGroup 的实现非常高效:

  • 使用原子操作进行计数
  • 基于信号量的等待机制避免忙等待
  • 零值即可安全使用
  • 无额外内存分配(栈分配)

在绝大多数场景下,WaitGroup 的性能开销可以忽略不计,可以放心使用。

实际应用案例

1. 并行文件处理

func processFiles(files []string) {
    var wg sync.WaitGroup
    
    for _, file := range files {
        wg.Add(1)
        go func(f string) {
            defer wg.Done()
            processFile(f)
        }(file)
    }
    
    wg.Wait()
}

2. 批量 API 请求

func fetchURLs(urls []string) map[string]string {
    results := make(map[string]string)
    var mu sync.Mutex
    var wg sync.WaitGroup
    
    for _, url := range urls {
        wg.Add(1)
        go func(u string) {
            defer wg.Done()
            
            resp, err := http.Get(u)
            if err != nil {
                return
            }
            defer resp.Body.Close()
            
            body, _ := io.ReadAll(resp.Body)
            
            mu.Lock()
            results[u] = string(body)
            mu.Unlock()
        }(url)
    }
    
    wg.Wait()
    return results
}

3. 分布式任务执行

func distributeTasks(tasks []Task, workers int) {
    taskCh := make(chan Task, len(tasks))
    var wg sync.WaitGroup
    
    // 创建工作池
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func(workerID int) {
            defer wg.Done()
            for task := range taskCh {
                processTask(workerID, task)
            }
        }(i)
    }
    
    // 分发任务
    for _, task := range tasks {
        taskCh <- task
    }
    close(taskCh)
    
    wg.Wait()
}

sync.WaitGroup 是 Go 语言并发编程中最基础且最强大的工具之一,掌握其正确用法对于编写高效可靠的并发程序至关重要。通过配合其他同步原语,可以构建出复杂的并发处理系统。