Go 标准库未提供可重入锁(ReentrantLock)实现。典型使用场景是在递归中使用可重入锁,避免死锁。
同时,sync.Cond
在 Wait()
时不支持设置超时。下面是 Wait()
方法的注释:
// Wait atomically unlocks c.L and suspends execution
// of the calling goroutine. After later resuming execution,
// Wait locks c.L before returning. Unlike in other systems,
// Wait cannot return unless awoken by Broadcast or Signal.
//
// Because c.L is not locked while Wait is waiting, the caller
// typically cannot assume that the condition is true when
// Wait returns. Instead, the caller should Wait in a loop:
//
// c.L.Lock()
// for !condition() {
// c.Wait()
// }
// ... make use of condition ...
// c.L.Unlock()
与在其它系统中不同,除非被 Broadcast
或 Signal
唤醒,否则,Wait
无法返回。也就是,Wait
不支持设置等待的超时时间。
下面将逐步实现可重入锁和支持设置等待超时时间的条件变量。
较新版本的 Go 未直接提供获取 Goroutine ID 的 API,但是可以间接地通过 runtime.Stack()
获取。以下代码参考自 Go 的标准库:
package sync2
import (
"bytes"
"fmt"
"runtime"
"strconv"
"sync"
)
// 因为 Goroutine ID 从 1 开始,所以 none 可以表示无效的 Goroutine ID
const none uint64 = 0
var goroutineSpace = []byte("goroutine ")
var littleBuf = sync.Pool{
New: func() any {
bytes := make([]byte, 64)
return &bytes
},
}
// GetCurrentGoroutineID 获取当前 Goroutine 的 ID
func GetCurrentGoroutineID() uint64 {
bp := littleBuf.Get().(*[]byte)
defer littleBuf.Put(bp)
b := *bp
b = b[:runtime.Stack(b, false)]
b = bytes.TrimPrefix(b, goroutineSpace)
i := bytes.IndexByte(b, ' ')
if i < 0 {
panic(fmt.Sprintf("No space found in %q", b))
}
b = b[:i]
n, err := strconv.ParseUint(string(b), 10, 64)
if err != nil {
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
}
return n
}
测试:
package sync2
import (
"context"
"math/rand"
"testing"
"time"
)
func TestGetCurrentGoroutineID(t *testing.T) {
gid := GetCurrentGoroutineID()
t.Logf("current goroutine id %d", gid)
var gidInGoroutine uint64
ctx, cancel := context.WithDeadline(
context.Background(),
time.Now().Add(time.Duration(rand.Int()%100)*time.Millisecond),
)
go func(cancel context.CancelFunc) {
gidInGoroutine = GetCurrentGoroutineID()
cancel()
}(cancel)
select {
case <-ctx.Done():
t.Logf("goroutine id %d", gidInGoroutine)
}
}
package sync2
import "sync"
// ReentrantLocker 表示可重入锁。它在 Locker 的基础上,新增三个与可重入相关的方法。
// 对于可重入锁而言,Unlock/Lock 不一定获取/释放底层锁,而可能只修改获取次数,因此引入 UnlockSave/LockRestore “真正地”释放/获取底层锁
type ReentrantLocker interface {
sync.Locker
// UnlockSave 释放底层锁,并且保存内部状态
UnlockSave() any
// LockRestore 获取底层锁,并且恢复内部状态
LockRestore(any)
// IsOwned 用于判断当前 Goroutine 是否持有该锁
IsOwned() bool
}
package sync2
import (
"sync"
"sync/atomic"
)
// ReentrantMutex 是可重入锁的一种实现
type ReentrantMutex struct {
// 持有锁的 Goroutine ID,默认值为 none,表示该锁未被任何 Goroutine 持有
heldBy uint64
// 持有锁的 Goroutine 每次获取锁时,该计数增加 1;每次释放锁时,该计数减少 1。
// 当获取次数大于释放次数,即 heldCount 大于 0 时,其它 Goroutine 无法获取锁;
// 当获取次数等于释放次数,即 heldCount 等于 0 时,锁被“真正地”释放;
// 当获取锁次数小于释放次数时,将引发 Panic
heldCount uint64
// 底层锁
mtx *sync.Mutex
}
// NewReentrantMutex 构建 ReentrantMutex 实例
func NewReentrantMutex() *ReentrantMutex {
var mtx sync.Mutex
return &ReentrantMutex{
heldBy:none,
heldCount: 0,
mtx: &mtx,
}
}
// IsOwned 返回当前 Goroutine 是否持有锁
func (r *ReentrantMutex) IsOwned() bool {
return atomic.LoadUint64(&r.heldBy) == GetCurrentGoroutineID()
}
// tryOwnerAcquire 在锁被当前 Goroutine 持有时,递增持有次数,并且返回 true, 表示当前 Goroutine 已获取锁;
// 否则,返回 false,表示当前 Goroutine 未持有锁,此时锁未被任何 Goroutine 持有,或者被其它 Goroutine 持有
func (r *ReentrantMutex) tryOwnerAcquire() bool {
if r.IsOwned() {
atomic.AddUint64(&r.heldCount, 1)
return true
}
return false
}
// acquireSlow 获取底层锁
func (r *ReentrantMutex) acquireSlow() {
r.mtx.Lock()
atomic.StoreUint64(&r.heldBy, GetCurrentGoroutineID())
atomic.StoreUint64(&r.heldCount, 1)
}
// Lock 获取锁
func (r *ReentrantMutex) Lock() {
if r.tryOwnerAcquire() {
return
}
r.acquireSlow()
}
// Unlock 释放锁
func (r *ReentrantMutex) Unlock() {
if r.IsOwned() {
r.releaseOnce()
return
}
panic("not held by current goroutine")
}
// releaseOnce 释放一次。仅当 heldCount 为 0 时,才“真正地”释放锁,否则只递减持有次数
func (r *ReentrantMutex) releaseOnce() {
// 释放次数大于获取次数
if atomic.LoadUint64(&r.heldCount) == 0 {
panic("unlocks more than locks")
}
// 递减持有次数
heldCount := atomic.AddUint64(&r.heldCount, ^uint64(0))
// 如果持有次数为 0,释放锁
if heldCount == 0 {
atomic.StoreUint64(&r.heldBy, none)
r.mtx.Unlock()
}
}
// LockRestore 获取锁,并且恢复状态。与 UnlockSave 配合使用
func (r *ReentrantMutex) LockRestore(state any) {
// 如果已被当前 Goroutine 持有,那么 Panic
if r.IsOwned() {
panic("already held by current goroutine")
}
// 获取锁
r.acquireSlow()
// 恢复状态
atomic.StoreUint64(&r.heldCount, state.(uint64))
}
// UnlockSave “真正地”释放锁,同时返回状态。与 LockRestore 配合使用
func (r *ReentrantMutex) UnlockSave() any {
// 如果未被当前 Goroutine 持有,那么 Panic
if !r.IsOwned() {
panic("not held by current goroutine")
}
// 获取,并且清空状态
heldCount := atomic.SwapUint64(&r.heldCount, 0)
// 释放锁
atomic.StoreUint64(&r.heldBy, none)
r.mtx.Unlock()
// 返回状态
return heldCount
}
测试:
package sync2
import "testing"
func TestReentrantMutex_Lock(t *testing.T) {
rLock := NewReentrantMutex()
rLock.Lock()
rLock.Lock()
t.Log("acquire twice")
rLock.Unlock()
if !rLock.IsOwned() {
t.Fatal("should own the lock")
}
rLock.Unlock()
if rLock.IsOwned() {
t.Fatal("should not own the lock")
}
}
package sync2
import (
"sync"
"sync/atomic"
"time"
)
// Cond2 与 sync.Cond 相比,在 Wait 时,支持设置超时
type Cond2 struct {
// 底层锁,可能是 sync.Mutex、sync.RWMutex 或者可重入锁
l sync.Locker
// 等待池
waiters []chan struct{}
// 底层锁是否是可重入锁
isReentrantLocker bool
// 持有 Cond2 的 Goroutine ID。仅在 isReentrantLocker 为 false 时使用
heldBy uint64
}
// NewCond2 构建 Cond2 实例
func NewCond2(l sync.Locker) *Cond2 {
isReentrantLocker := false
if _, ok := l.(ReentrantLocker); ok {
isReentrantLocker = true
}
return &Cond2{
l: l,
waiters: make([]chan struct{}, 0),
isReentrantLocker: isReentrantLocker,
heldBy:none,
}
}
// IsOwned 返回当前 Goroutine 是否持有该 Cond2 实例
func (c *Cond2) IsOwned() bool {
if !c.isReentrantLocker {
return atomic.LoadUint64(&c.heldBy) == GetCurrentGoroutineID()
}
return c.l.(ReentrantLocker).IsOwned()
}
// Lock 获取锁
func (c *Cond2) Lock() {
c.l.Lock()
if !c.isReentrantLocker {
atomic.StoreUint64(&c.heldBy, GetCurrentGoroutineID())
}
}
// Unlock 释放锁
func (c *Cond2) Unlock() {
if !c.IsOwned() {
panic("underlying lock not held by current goroutine")
}
if !c.isReentrantLocker {
atomic.StoreUint64(&c.heldBy, none)
}
c.l.Unlock()
}
// Wait 相当于 sync.Cond Wait,但支持设置超时时间。返回 false 表示因超时而结束等待
func (c *Cond2) Wait(timeout time.Duration) bool {
if !c.IsOwned() {
panic("cannot wait on un-locked lock")
}
waiter := make(chan struct{})
c.waiters = append(c.waiters, waiter)
// 释放底层锁
var savedState any
if l, ok := c.l.(ReentrantLocker); ok {
savedState = l.UnlockSave()
} else {
c.Unlock()
}
// 等待被唤醒或超时
gotIt := false
if timeout <= 0 {
select {
case <-waiter:
gotIt = true
}
} else {
select {
case <-time.NewTimer(timeout).C:
case <-waiter:
gotIt = true
}
}
// 获取底层锁
if l, ok := c.l.(ReentrantLocker); ok {
l.LockRestore(savedState)
} else {
c.Lock()
}
select {
case <-waiter:
gotIt = true
default:
}
// 如果未被通知而超时,那么从 waiters 中移除 waiter
if !gotIt {
for idx, _waiter := range c.waiters {
if _waiter == waiter {
c.waiters = append(c.waiters[:idx], c.waiters[idx+1:]...)
break
}
}
}
return gotIt
}
// Notify 通知等待池中的 Goroutine
func (c *Cond2) Notify(n int) {
if !c.IsOwned() {
panic("underlying lock not held by current goroutine")
}
if len(c.waiters) == 0 {
return
}
removed := -1
for idx, waiter := range c.waiters {
if idx < n {
removed++
close(waiter)
} else {
break
}
}
// 从 Waiting 池中移除已通知的 waiter
c.waiters = c.waiters[removed+1:]
}
// NotifyOne 相当于 sync.Cond Signal
func (c *Cond2) NotifyOne() {
c.Notify(1)
}
// NotifyAll 相当于 sync.Cond Broadcast
func (c *Cond2) NotifyAll() {
c.Notify(len(c.waiters))
}
测试:
package sync2
import (
"sync"
"testing"
"time"
)
func TestCond2(t *testing.T) {
c := NewCond2(&sync.Mutex{})
c.Lock()
gotIt := c.Wait(100 * time.Millisecond)
if gotIt {
t.Fatal("should not get it")
}
c.Unlock()
startCh := make(chan struct{})
stopCh := make(chan struct{})
go func() {
c.Lock()
close(startCh)
gotIt = c.Wait(0)
c.Unlock()
close(stopCh)
}()
select {
case <-startCh:
}
time.Sleep(100 * time.Millisecond)
c.Lock()
c.NotifyAll()
c.Unlock()
select {
case <-stopCh:
}
if !gotIt {
t.Fatal("should get it")
}
}