Go 标准库未提供可重入锁(ReentrantLock)实现。典型使用场景是在递归中使用可重入锁,避免死锁。

同时,sync.Cond 的 Wait() 方法不支持设置超时。下面是 Wait() 方法的注释:

// Wait atomically unlocks c.L and suspends execution
// of the calling goroutine. After later resuming execution,
// Wait locks c.L before returning. Unlike in other systems,
// Wait cannot return unless awoken by Broadcast or Signal.
//
// Because c.L is not locked while Wait is waiting, the caller
// typically cannot assume that the condition is true when
// Wait returns. Instead, the caller should Wait in a loop:
//
//  c.L.Lock()
//  for !condition() {
//      c.Wait()
//  }
//  ... make use of condition ...
//  c.L.Unlock()

与在其它系统中不同,除非被 Broadcast 或 Signal 唤醒,否则 Wait 无法返回。也就是说,Wait 不支持设置等待的超时时间。

下面将逐步实现可重入锁和支持设置等待超时时间的条件变量。


1. 获取 Goroutine ID

较新版本的 Go 未直接提供获取 Goroutine ID 的 API,但是可以间接地通过 runtime.Stack() 获取。以下代码参考自 Go 的标准库:

package sync2

import (
    "bytes"
    "fmt"
    "runtime"
    "strconv"
    "sync"
)

// none is the sentinel meaning "no goroutine". Goroutine IDs start at 1,
// so 0 never collides with a valid ID.
const none uint64 = 0

var (
	// goroutineSpace is the literal prefix that is stripped from a stack
	// header before the goroutine ID is parsed out of it.
	goroutineSpace = []byte("goroutine ")

	// littleBuf pools 64-byte scratch buffers. Entries are stored as *[]byte
	// so that taking one from the pool and putting it back does not copy the
	// slice header into a fresh interface allocation.
	littleBuf = sync.Pool{
		New: func() any {
			scratch := make([]byte, 64)
			return &scratch
		},
	}
)

// GetCurrentGoroutineID returns the ID of the calling goroutine.
//
// The ID is parsed out of the first bytes written by runtime.Stack for the
// current goroutine: the "goroutine " prefix is stripped and the text up to
// the next space is read as a base-10 unsigned integer. The function panics
// if that text contains no space or is not a valid number.
func GetCurrentGoroutineID() uint64 {
	scratch := littleBuf.Get().(*[]byte)
	defer littleBuf.Put(scratch)

	// Only the current goroutine's trace is requested; whatever does not fit
	// into the pooled buffer is truncated, which is fine because the ID sits
	// at the very start.
	header := (*scratch)[:runtime.Stack(*scratch, false)]
	header = bytes.TrimPrefix(header, goroutineSpace)

	end := bytes.IndexByte(header, ' ')
	if end < 0 {
		panic(fmt.Sprintf("No space found in %q", header))
	}

	digits := header[:end]
	id, err := strconv.ParseUint(string(digits), 10, 64)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", digits, err))
	}
	return id
}

测试:

package sync2

import (
    "context"
    "math/rand"
    "testing"
    "time"
)

// TestGetCurrentGoroutineID checks that a freshly started goroutine reports a
// valid ID that differs from the ID of the test goroutine.
func TestGetCurrentGoroutineID(t *testing.T) {
	gid := GetCurrentGoroutineID()
	t.Logf("current goroutine id %d", gid)

	var gidInGoroutine uint64
	// The deadline is only a safety net. It keeps a random jitter but has a
	// 100ms floor so the worker goroutine always has time to run first.
	ctx, cancel := context.WithDeadline(
		context.Background(),
		time.Now().Add(time.Duration(100+rand.Intn(100))*time.Millisecond),
	)
	defer cancel()

	go func() {
		gidInGoroutine = GetCurrentGoroutineID()
		// cancel closes ctx.Done(), which publishes the write above to the
		// goroutine that receives from it.
		cancel()
	}()

	<-ctx.Done()
	// gidInGoroutine may only be read when the context was cancelled by the
	// worker; after a deadline expiry the worker could still be writing it.
	if err := ctx.Err(); err != context.Canceled {
		t.Fatalf("goroutine did not report its ID before the deadline: %v", err)
	}
	t.Logf("goroutine id %d", gidInGoroutine)
	if gidInGoroutine == 0 || gidInGoroutine == gid {
		t.Fatalf("unexpected goroutine id %d (test goroutine id %d)", gidInGoroutine, gid)
	}
}

2. 定义可重入锁接口

package sync2

import "sync"

// ReentrantLocker describes a reentrant lock. It extends sync.Locker with
// three methods related to reentrancy.
//
// For a reentrant lock, Lock/Unlock do not necessarily acquire/release the
// underlying lock; they may only adjust the hold count. UnlockSave and
// LockRestore exist to "really" release and re-acquire the underlying lock.
type ReentrantLocker interface {
    sync.Locker
    // UnlockSave releases the underlying lock and returns the saved internal state.
    UnlockSave() any
    // LockRestore acquires the underlying lock and restores the internal state.
    LockRestore(any)
    // IsOwned reports whether the current goroutine holds the lock.
    IsOwned() bool
}

3. 基于 Mutex 实现 ReentrantMutex

package sync2

import (
    "sync"
    "sync/atomic"
)

// ReentrantMutex is one implementation of a reentrant lock.
type ReentrantMutex struct {
    // heldBy is the ID of the goroutine holding the lock. The default value
    // none means the lock is not held by any goroutine.
    heldBy uint64
    // heldCount is incremented each time the holding goroutine acquires the
    // lock and decremented each time it releases it.
    // While acquisitions outnumber releases (heldCount > 0), other goroutines
    // cannot acquire the lock; when they are equal (heldCount == 0), the lock
    // is "really" released; releasing more often than acquiring panics.
    heldCount uint64
    // mtx is the underlying lock.
    mtx *sync.Mutex
}

// NewReentrantMutex returns a ReentrantMutex that is not held by any
// goroutine: the owner is none, the hold count is zero, and a fresh
// sync.Mutex serves as the underlying lock.
func NewReentrantMutex() *ReentrantMutex {
	m := new(ReentrantMutex)
	m.heldBy = none
	m.heldCount = 0
	m.mtx = new(sync.Mutex)
	return m
}

// IsOwned reports whether the calling goroutine is the one currently recorded
// as the holder of the lock.
func (r *ReentrantMutex) IsOwned() bool {
	owner := atomic.LoadUint64(&r.heldBy)
	return owner == GetCurrentGoroutineID()
}

// tryOwnerAcquire handles re-entry. If the calling goroutine already holds the
// lock, the hold count is incremented and true is returned. Otherwise nothing
// is changed and false is returned; the lock is then either free or held by
// another goroutine.
func (r *ReentrantMutex) tryOwnerAcquire() bool {
	if !r.IsOwned() {
		return false
	}
	atomic.AddUint64(&r.heldCount, 1)
	return true
}

// acquireSlow blocks until the underlying mutex is acquired, then records the
// calling goroutine as the owner with a hold count of 1.
func (r *ReentrantMutex) acquireSlow() {
	r.mtx.Lock()

	// Ownership is published only after the underlying mutex is held.
	self := GetCurrentGoroutineID()
	atomic.StoreUint64(&r.heldBy, self)
	atomic.StoreUint64(&r.heldCount, 1)
}

// Lock acquires the lock. A goroutine that already holds it only bumps the
// hold count; any other goroutine falls through to the underlying mutex.
func (r *ReentrantMutex) Lock() {
	if !r.tryOwnerAcquire() {
		r.acquireSlow()
	}
}

// Unlock undoes one Lock call made by the calling goroutine. It panics when
// the calling goroutine does not hold the lock.
func (r *ReentrantMutex) Unlock() {
	if !r.IsOwned() {
		panic("not held by current goroutine")
	}
	r.releaseOnce()
}

// releaseOnce decrements the hold count by one. The underlying mutex is only
// released when the count reaches zero; otherwise the lock stays held. It
// panics if the hold count is already zero.
func (r *ReentrantMutex) releaseOnce() {
	// More releases than acquisitions.
	if atomic.LoadUint64(&r.heldCount) == 0 {
		panic("unlocks more than locks")
	}

	// Adding ^uint64(0) subtracts one from an unsigned counter.
	if remaining := atomic.AddUint64(&r.heldCount, ^uint64(0)); remaining != 0 {
		return
	}

	// Last release: clear the owner before letting other goroutines in.
	atomic.StoreUint64(&r.heldBy, none)
	r.mtx.Unlock()
}

// LockRestore acquires the underlying mutex and restores the hold count that
// a previous UnlockSave returned. It is the counterpart of UnlockSave.
//
// state must be the uint64 value produced by UnlockSave. LockRestore panics
// if the calling goroutine already holds the lock, or if state is not a
// uint64.
func (r *ReentrantMutex) LockRestore(state any) {
	if r.IsOwned() {
		panic("already held by current goroutine")
	}
	// Check the state before taking the mutex. A failed type assertion then
	// panics without leaving the lock held by a goroutine that will never
	// release it.
	heldCount := state.(uint64)

	r.acquireSlow()
	// acquireSlow sets the count to 1; replace it with the saved depth.
	atomic.StoreUint64(&r.heldCount, heldCount)
}

// UnlockSave fully releases the lock regardless of how many times the calling
// goroutine has acquired it, and returns the hold count (a uint64) so that
// LockRestore can reinstate it later. It panics when the calling goroutine
// does not hold the lock.
func (r *ReentrantMutex) UnlockSave() any {
	if owned := r.IsOwned(); !owned {
		panic("not held by current goroutine")
	}

	// Read and reset the hold count in a single step.
	saved := atomic.SwapUint64(&r.heldCount, 0)

	// Clear the owner, then release the underlying mutex.
	atomic.StoreUint64(&r.heldBy, none)
	r.mtx.Unlock()

	return saved
}

测试:

package sync2

import "testing"

// TestReentrantMutex_Lock acquires the lock twice from the same goroutine and
// checks that ownership is kept after the first Unlock and dropped after the
// second.
func TestReentrantMutex_Lock(t *testing.T) {
	mu := NewReentrantMutex()

	for i := 0; i < 2; i++ {
		mu.Lock()
	}
	t.Log("acquire twice")

	mu.Unlock()
	if owned := mu.IsOwned(); !owned {
		t.Fatal("should own the lock")
	}

	mu.Unlock()
	if owned := mu.IsOwned(); owned {
		t.Fatal("should not own the lock")
	}
}

4. 实现支持设置超时的条件变量

package sync2

import (
    "sync"
    "sync/atomic"
    "time"
)

// Cond2 is a counterpart of sync.Cond whose Wait supports a timeout.
type Cond2 struct {
    // l is the underlying lock; it may be a sync.Mutex, a sync.RWMutex, or a
    // reentrant lock.
    l sync.Locker
    // waiters is the pool of waiting goroutines, one channel per waiter.
    waiters []chan struct{}
    // isReentrantLocker records whether the underlying lock is a reentrant lock.
    isReentrantLocker bool
    // heldBy is the ID of the goroutine holding the Cond2. It is used only
    // when isReentrantLocker is false.
    heldBy uint64
}

// NewCond2 returns a Cond2 built on top of l. Whether l implements
// ReentrantLocker is determined once here and remembered; the waiter list
// starts out empty and no goroutine is recorded as the holder.
func NewCond2(l sync.Locker) *Cond2 {
	_, reentrant := l.(ReentrantLocker)

	c := &Cond2{l: l, heldBy: none}
	c.waiters = make([]chan struct{}, 0)
	c.isReentrantLocker = reentrant
	return c
}

// IsOwned reports whether the calling goroutine holds this Cond2. With a
// reentrant underlying lock the question is delegated to that lock; otherwise
// the goroutine ID recorded by Lock is compared with the caller's ID.
func (c *Cond2) IsOwned() bool {
	if c.isReentrantLocker {
		return c.l.(ReentrantLocker).IsOwned()
	}
	owner := atomic.LoadUint64(&c.heldBy)
	return owner == GetCurrentGoroutineID()
}

// Lock acquires the underlying lock. When that lock is not reentrant, the
// calling goroutine's ID is recorded so that IsOwned can answer later.
func (c *Cond2) Lock() {
	c.l.Lock()
	if c.isReentrantLocker {
		// A reentrant lock tracks its own owner.
		return
	}
	atomic.StoreUint64(&c.heldBy, GetCurrentGoroutineID())
}

// Unlock releases the underlying lock. It panics when the calling goroutine
// does not hold it. For a non-reentrant lock the recorded owner is cleared
// before the lock is released.
func (c *Cond2) Unlock() {
	switch {
	case !c.IsOwned():
		panic("underlying lock not held by current goroutine")
	case !c.isReentrantLocker:
		atomic.StoreUint64(&c.heldBy, none)
	}
	c.l.Unlock()
}

// Wait is the counterpart of sync.Cond's Wait with an optional timeout.
//
// The caller must hold the lock, otherwise Wait panics. Wait registers the
// caller as a waiter, releases the underlying lock, blocks until it is
// notified or until timeout elapses, and re-acquires the lock before
// returning. A timeout <= 0 means wait without a time limit.
//
// It returns true when the caller was notified and false when the wait ended
// because of the timeout. A notification that arrives after the timeout but
// before the lock is re-acquired still counts as notified.
func (c *Cond2) Wait(timeout time.Duration) bool {
	if !c.IsOwned() {
		panic("cannot wait on un-locked lock")
	}

	// Register while still holding the lock so that a Notify issued right
	// after the release below cannot miss this waiter.
	waiter := make(chan struct{})
	c.waiters = append(c.waiters, waiter)

	// Release the underlying lock. A reentrant lock must be released fully,
	// with its hold count remembered for later.
	var savedState any
	rl, reentrant := c.l.(ReentrantLocker)
	if reentrant {
		savedState = rl.UnlockSave()
	} else {
		c.Unlock()
	}

	// Block until notified or timed out.
	if timeout <= 0 {
		<-waiter
	} else {
		timer := time.NewTimer(timeout)
		select {
		case <-waiter:
		case <-timer.C:
		}
		// Release the timer promptly instead of letting it linger until it
		// fires when the notification came first.
		timer.Stop()
	}

	// Re-acquire the underlying lock.
	if reentrant {
		rl.LockRestore(savedState)
	} else {
		c.Lock()
	}

	// With the lock held, the channel state is authoritative: Notify closes
	// waiter channels under the lock and a closed channel stays closed.
	select {
	case <-waiter:
		return true
	default:
	}

	// Timed out without being notified: withdraw from the waiter list.
	for idx, w := range c.waiters {
		if w != waiter {
			continue
		}
		last := len(c.waiters) - 1
		copy(c.waiters[idx:], c.waiters[idx+1:])
		c.waiters[last] = nil // drop the stale reference in the vacated slot
		c.waiters = c.waiters[:last]
		break
	}
	return false
}

// Notify wakes up to n waiting goroutines, oldest registration first, by
// closing their channels and dropping them from the waiter list. When n
// exceeds the number of waiters, all of them are woken; n <= 0 wakes nobody.
// It panics when the calling goroutine does not hold the lock.
func (c *Cond2) Notify(n int) {
	if !c.IsOwned() {
		panic("underlying lock not held by current goroutine")
	}

	if pending := len(c.waiters); n > pending {
		n = pending
	}
	if n <= 0 {
		return
	}

	for _, w := range c.waiters[:n] {
		close(w)
	}
	// Forget the waiters that have just been notified.
	c.waiters = c.waiters[n:]
}

// NotifyOne is the counterpart of sync.Cond's Signal. It delegates to Notify
// with a count of 1.
func (c *Cond2) NotifyOne() {
    c.Notify(1)
}

// NotifyAll is the counterpart of sync.Cond's Broadcast. It delegates to
// Notify with the current number of waiters.
func (c *Cond2) NotifyAll() {
    c.Notify(len(c.waiters))
}

测试:

package sync2

import (
    "sync"
    "testing"
    "time"
)

// TestCond2 covers both outcomes of Cond2.Wait: first a wait that nobody
// notifies and that therefore ends by timeout, then an unbounded wait in a
// second goroutine that is released by NotifyAll.
func TestCond2(t *testing.T) {
	cond := NewCond2(&sync.Mutex{})

	// Nobody notifies, so the wait must end because of the timeout.
	cond.Lock()
	signaled := cond.Wait(100 * time.Millisecond)
	if signaled {
		t.Fatal("should not get it")
	}
	cond.Unlock()

	// A second goroutine waits without a time limit until it is notified.
	started := make(chan struct{})
	finished := make(chan struct{})
	go func() {
		cond.Lock()
		close(started)
		signaled = cond.Wait(0)
		cond.Unlock()
		close(finished)
	}()

	<-started
	time.Sleep(100 * time.Millisecond)

	cond.Lock()
	cond.NotifyAll()
	cond.Unlock()

	// Closing finished happens after the write to signaled, so the read
	// below observes the waiter's result.
	<-finished
	if !signaled {
		t.Fatal("should get it")
	}
}