Go语言设计(6)-WaitGroup

Go语言设计(6)-WaitGroup

Date
Feb 16, 2022
Tags
Go
Concurrency

基础使用

Go语言中除了使用channel和互斥锁进行两个并发程序的同步外,还可以使用等待组(sync.WaitGroup)进行多个任务的同步, 等待组可以保证在并发环境中完成指定数量任务
每个 sync.WaitGroup 值在内部维护着一个计数,此计数的初始默认值为零。
  • (wg * WaitGroup) Add(delta int) 等待组的计数器 +1
  • (wg * WaitGroup) Done() 等待组的计数器 -1
  • (wg * WaitGroup) Wait()当等待组计数器不等于 0 时阻塞直到变 0。
 
对于一个可寻址的 sync.WaitGroup 值 wg:
  • 我们可以使用方法调用 wg.Add(delta) 来改变值 wg 维护的计数。
  • 方法调用 wg.Done() wg.Add(-1) 是完全等价的。
  • 如果一个 wg.Add(delta) 或者 wg.Done() 调用将 wg 维护的计数更改成一个负数,一个恐慌将产生。
  • 当一个协程调用了 wg.Wait() 时,
    • 如果此时 wg 维护的计数为零,则此 wg.Wait() 此操作为一个空操作(noop);
    • 否则(计数为一个正整数),此协程将进入阻塞状态。当以后其它某个协程将此计数更改至 0 时(一般通过调用 wg.Done()),此协程将重新进入运行状态(即 wg.Wait() 将返回)。
等待组内部拥有一个计数器,计数器的值可以通过方法调用实现计数器的增加和减少。当我们添加了 N 个并发任务进行工作时,就将等待组的计数器值增加 N。每个任务完成时,这个值减 1。同时,在另外一个 goroutine 中等待这个等待组的计数器值为 0 时,表示所有任务已经完成。
package main

import (
	"fmt"
	"net/http"
	"sync"
)

func main() {

	wg := sync.WaitGroup{}
	urls := []string{
		"http://www.baidu.com",
		"https://www.qiuniu.com",
		"https://www.golangtc.com",
	}

	for _, url := range urls {
		// 每一个任务增加1 - 协程之前保证有1堵塞
		wg.Add(1)

		// 开启一个并发
		go func(url string) {
			defer wg.Done()

			_, err := http.Get(url)
			fmt.Println(url, err)
		}(url)
	}

	wg.Wait()
	fmt.Println("over")
}

源码解析

github地址 , 下面通过复习一些前置知识,来理解WaitGroup

前置知识

信号量

信号量是一种保护共享资源的机制,用于解决多线程同步的问题。信号量S是具有非负整数值的全局变量,只能由两种特殊的操作来处理:
  • P(s): 如果S是非零,那么P将s减1,并且立即返回.如果s为零,那么就挂起这个线程,直到变为非零,等到另外一个执行V(s)操作线程来唤醒该线程。在唤醒之后P将s减1,并将控制返回给调用者
  • V(s):V操作将s加1, 如果有任务线程阻塞在P操作等待S变成非零,那么V操作会唤醒这些线程中一个,然后将该线程将S减1,完成它的P操作
 
go底层的信号量函数runtime_Semacquire(s *uint32)函数会阻塞goroutine直到信号量s值大于0,然后原子性减这个值,即P操作。
runtime_Semrelease(s *uint32, lifo bool, skipframes int) 函数原子性增加信号量的值,然后通知被阻塞的goroutine, 即V操作

内存对齐

package main

import (
	"fmt"
	"unsafe"
)

type Ins struct {
	x bool  //1
	z byte  // 1
	y int32 // 4
}

type Ins2 struct {
	x bool  //1
	y int32 // 4
	z byte  // 1
}

func main() {
	ins := Ins{}
	fmt.Printf("ins size %d, align: %d", unsafe.Sizeof(ins), unsafe.Alignof(ins))
	fmt.Print("\n")
	ins1 := Ins2{}
	fmt.Printf("ins size %d, align: %d", unsafe.Sizeof(ins1), unsafe.Alignof(ins1))
	fmt.Print("\n")
}
}

// output: size: 8, align: 4
// output: size: 12, align: 4
我们直到CPU的内存读取不是一字节一字节读取的,而是一块一块的,因此在类型的值在内存中对齐的情况下,计算机的加载或者写入会更加高效。
在聚合类型的内存所占长度或许会比它元素所占内存之和更大。编译器会添加未使用的内存地址来填充内存空隙,以确保连续的成员或元素相当于结构体或数组起始地址是对齐的
notion image
所以,我们设计结构体的时候,当结构体成员类型不同时,将相同类型成员定义在相邻位置可以节省内存空间

原子操作CAS

可以用于多线程编程中实现不被打断的数据交换操作,从而避免多线程同时改写某一个数据时由于执行顺序不确定性以及中断的不可预知性产生的数据不一致问题

移位运算 >> 与 <<

  • 左移位运算<< , 按照二进制形式将所有的数字向左移动对应的位数,高位舍弃,低位的空位补零。在数字没有溢出的前提下,左移一位相当于乘以2的一次方,左移n位相当于乘以2的n次方.
  • 右移位运算>>, 按二进制形式把所有的数字向右移动对应位数,低位移出,高位的空位补符号位。右移一位相当于除2,右移n位相当于除以2的n次方。这里是取商,余数就不要了。

unsafe.Pointr指针与uintptr

go的指针分成三类:
  • 普通指针*T 用于传递对象地址,不能进行指针计算
  • unsafe.Pointer 指针: 通用型指针, 任何一个普通类型的指针*T都可以转换成为Pointer指针,不能读取内存中的值(转换为某一具体类型普通指针才行)
  • uintptr: 是一个大小并不明确的无符号整型; 可以与unsafe.Pointer互相转换,可以通过该数值进行指针运算.
notion image
 

源码分析(v1.17)

结构体

包含一个noCopy的辅助字段,和一个具有复合意义的state1字段
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
其中noCopy是空结构体,不会占用内存,编译器也不会对其进行字节填充。主要是为了通过go vet工具来做静态编译检查,防止开发者在使用WaitGroup过程中对其进行复制,造成安全隐患。
state1字段是一个长度为3的uint32数组,用于表示三部分内容 :
  • 通过Add()设置的goroutine的计数值counter;
  • 通过Wait()陷入堵塞的waiter数
  • 信号量semap
notion image
这边如何奇怪的设定,涉及两个前提:
  • 在真实逻辑中,counter和waiter是被合在一起,当成一个64位额整数对外使用。当变化counter和waiter的值是,可以通过atomic来原子操作这个64位整数.
  • 在32位系统下,使用atomic对64为变量进行原子操作,调用者需要自行保证变量64位对齐,否则会出现异常。
 
内存对齐实现
  • 当state1变量是64位(8byte)对齐时,数组前两位作为64位整数时自然也是64位对齐
  • 当state1变量是32位(4byte)对齐时,我们把数组第一位作对齐padding, 因此state1本身是uint32数组,所以数组第一位也是32位, 将数组后两位看做统一的64位整数时64位对齐
 

Add函数

Add()函数的入参是一个整型,它可正可负,是对counter数值的更改。如果counter数值变为0,那么所有阻塞在Wait()函数的waiter将会被唤醒;如果counter数值为负值,将引起panic
// 除去竞态检测的代码
func (wg *WaitGroup) Add(delta int) {
// 获取包含counter与waiter的复合状态statep,表示信号量值的semap
	statep, semap := wg.state()

 // 赋值逻辑
	state := atomic.AddUint64(statep, uint64(delta)<<32)  //新增counter数值delta
	v := int32(state >> 32)  //获取counter值
	w := uint32(state)   //获取waiter值

 // case1: 这是counter数不能为负
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
 // case2: misuse 引起panic
 // 因为wg其实可以复用的,但是下一次复用的基础是需要将所有状态重置为0
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
  // case3: 本次Add操作只负责增加counter值,直接返回
 // 如果counter大于0,唤醒操作留给之后Add的调用者
 // 如果waiter为0,代表此时还没有阻塞的waiter
	if v > 0 || w == 0 {
		return
	}

 // case4: misuse 引起的panic
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
  // 如果执行到这,一定是 counter=0,waiter>0
  // 能执行到这,一定是执行了Add(-x)的goroutine
  // 它的执行,代表所有子goroutine已经完成了任务
  // 因此,我们需要将复合状态全部归0,并释放掉waiter个数的信号量
	*statep = 0
	for ; w != 0; w-- {
// 释放信号量,执行一次就将唤醒一个阻塞的waiter
		runtime_Semrelease(semap, false, 0)
	}
}
此时statep是一个uint64数值,如果此时statep中包含的counter数为2, waiter为1,输入delta为1, 那么赋值逻辑如下图
notion image

Done函数

Done()函数比较简单,就是调用Add(-1), 实际使用中,当子goroutine任务完成之后,应该调用Done()函数
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait函数

如果WaitGroup中的counter值大于0,那么执行Wait()函数的主goroutine会将waiter值加1,并阻塞等待该值为0,才能继续执行后续代码.
//去除竟态检测代码后
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()

	for {
		state := atomic.LoadUint64(statep)  //原子复合状态statep
		v := int32(state >> 32)  //获取counter值
		w := uint32(state)   //获取waiters值

   // 如果此时v==0, 证明已经没有待执行任务的子goroutine,直接退出
		if v == 0 {
			// Counter is 0, no need to wait.
			return
		}

    // 如果执行CAS原子操作和读取符合状态之间,没有其他goroutine更改符合状态
    // 那么就将waiter值+1; 否则进入下一轮循环,重新读取复合状态
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
    
	    // 对waiter累加成功后
      // 等待Add 函数中调用runtime_Semrelease唤醒自己
			runtime_Semacquire(semap)

     // reused 引发的panic
     // 当前goroutine被唤醒时,由于唤醒自己的goroutine通过调用Add方法
     // 通过 *statep = 0 语句做了重置操作
    // 此时复合状态位不为0, 就是因为还未等Waiter执行完wait,waitgroup已经发生了复用
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}
 

总结流程

WaitGroup的实现思路还是比较简单的,通过结构体字段state1维护两个计数器和一个信号量,计数器分别是通过Add()添加子goroutine的计数值counter,通过Wait()陷入阻塞的waiter数,信号量用于阻塞与唤醒Waiter。
当执行Add(positive n)时, counter+=n 表明新增了n个goroutine执行任务.每个子goroutine完成任务之后,需要调用Done()将counter数减1, 最后一个子goroutine完成时,counter值也会是0,此时就需要唤醒阻塞在Wait()调用中的Waiter
 
无锁实现
  1. 本身counterwaiter改变时需要保证并发安全。对于这种场景,可以使用一个Mutex或者RWMutex锁,在进行读写时加锁即可,但是这样会有额外的性能开销
  1. WaitGroup直接把counter和waiter看成统一的64位变量(而非拆成两个独立变量), 其中counter是高32位,waiter是底32位。在改变时通过累加值左移32位的方式: atomic.AddUint64(statep, uint64(delta)<<32),实现了counter+=delta的效果
  1. 在Wait函数中, 通过CAS操作atomic.CompareAndSwapUint64(statep, state, state+1),来对waiter进行自增操作,如果返回false,说明state变量有修改,可能counter发生了变化,
  1. WaitGroup本身是可以复用的,因此在Wait结束时候,需要将waiter—,重置状态。但会涉及一次原子变量操作,若是Wait的goroutine比较多,那么这个原子操作也会随之进行很多次。但是WaitGroup这边直接在Done时,当counter等于0时,直接将counter+waiter整个64位全部置0,也可以达到重置状态的效果,避免进行多次原子操作。

注意事项

  • 通过Add()函数添加的counter数一定要与后续通过Done()减去的数值一致,如果前者大,那么阻塞在Wait()调用处的goroutine就得不到唤醒;如果后者大,将会引起panic
  • Add()的增量函数应该最先得到执行
  • 不要对WaitGroup对象进行复制使用
  • 如果要复用WaitGroup, 则必须在所有先前的Wait()调用返回之后在进行新的Add()
 

使用示例

工作池
package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

func toilet(i int, ch chan int, wg *sync.WaitGroup) {
	defer wg.Done()
	fmt.Println(i, "进入厕所坑位")
	rand.Seed(int64(i))
	t := rand.Intn(10)
	time.Sleep(time.Duration(t) * time.Second)
	fmt.Println(<-ch, "离开厕所坑位")
}

func main() {

	ch := make(chan int, 3) // 三个坑位
	wg := sync.WaitGroup{}
	wg.Add(10) // 10 个cap
	for i := 0; i < 10; i++ {
		ch <- i
		go toilet(i, ch, &wg)
	}
	wg.Wait()
交替打印: 三个函数,分别打印 cat,dog,fish
func output1() {

	cat, dog, fish := make(chan bool), make(chan bool), make(chan bool)
	wg := sync.WaitGroup{}
	loopCount := 20

	wg.Add(3)

	go func(wg *sync.WaitGroup) {
		defer wg.Done()
		i := 0

		for {
			select {
			case <-cat:
				if i >= loopCount {
					fmt.Println("cat over")
					dog <- true
					return
				}
				fmt.Println("cat")
				i++
				dog <- true
				break
			default:
				break
			}
		}
	}(&wg)

	go func(wg *sync.WaitGroup) {
		defer wg.Done()
		i := 0
		for {
			select {
			case <-dog:
				if i >= loopCount {
					fmt.Println("dog over")
					fish <- true
					return
				}
				fmt.Println("dog")
				i++
				fish <- true
				break //结束退出循环
			default:
				break //进入下一个循环
			}
		}

	}(&wg)

	go func(wg *sync.WaitGroup) {
		defer wg.Done()
		i := 0
		for {
			select {
			case <-fish:
				if i >= loopCount {
					fmt.Println("fish over")
					return
				}
				fmt.Println("fish")
				i++
				cat <- true
				break
			default:
				break
			}
		}

	}(&wg)

	cat <- true
	wg.Wait()

}
 

Loading Comments...