目录


概述

屏障(Barrier)会阻塞所有在该屏障上等待(wait)的线程,一直到指定数量的线程进入了该屏障,屏障才会被解除,这些线程才能继续运行。也就是说,屏障用于在同步点上,同步确定数量的线程
本文是在Python3.7测试通过的。Barrier是在Python3.2引入到threading模块的。
Python中的Barrier其实是一个“循环屏障”,可以查看这个例子


例子


from threading import Thread, Barrier
from time import sleep
import logging

logging.basicConfig(level=logging.INFO,
    format="[thread:%(threadName)s] [%(asctime)s] %(message)s",
    datefmt="%F %T")

logger = logging.getLogger(__name__)

b = Barrier(2)

def func(t, b):
    sleep(t)
    logger.info("enter into barrier")
    b.wait()
    logger.info("exit barrier")

t1 = Thread(target=func, name="t1", args=(3, b))
t2 = Thread(target=func, name="t2", args=(5, b))
t1.start()
t2.start()

logger.info("start")
t1.join()
t2.join()
logger.info("end")


源码解析

源代码在:https://github.com/python/cpython/blob/3.7/Lib/threading.py
阅读本部分之前,需要对条件变量有了解,否则,请先阅读:Python Condition源码解析
Barrier对象内部会维护一个状态:self._state。其取值为:

接下来,用一张状态迁移图,来描述Barrier对象的内部状态的部分相互转化关系:
barrier.jpg

Python3的Condition增加了wait_for方法,并且在Barrier类中用到了该方法,因此先看一下,它的作用:


    def wait_for(self, predicate, timeout=None):
        # predicate是一个可调用对象,调用该可调用对象,需要返回一个布尔值

        # 当timeout是None的时候,wait_for方法,首先会调用
        #  + predicate,如果返回True,则wait_for返回True;
        #  + 否则,该方法会阻塞,直到被唤醒。
        #  + 被唤醒之后,会继续调用predicate,...,如此循环,直到调用predicate返回True

        # 当timeout不是None的时候,wait_for方法,首先会调用predicate,
        #  + 如果返回True,则wait_for返回True;
        #  + 否则,该方法会阻塞,直到超时或被唤醒。
        #  + 在超时或被唤醒之后,会继续调用predicate,...,如此循环,直到:
        #  + + 1,在timeout期间内,调用predicate返回了True,那么wait_for返回True
        #  + + 2,在timeout期间内,调用predicate一直返回False,
        #  + + + 那么,wait_for会等待到超时,然后返回False

        # 也就是说,如果该方法返回了True,那么调用predicate就会返回True;否则,返回False

        endtime = None
        waittime = timeout
        result = predicate()
        while not result:
            if waittime is not None:
                if endtime is None:
                    endtime = _time() + waittime
                else:
                    waittime = endtime - _time()
                    if waittime <= 0:
                        break
            self.wait(waittime)
            result = predicate()
        return result

下面是Barrier类的主要源码:


class Barrier:
    def __init__(self, parties, action=None, timeout=None):
        # parties: 线程数量
        # action: 是一个可调用对象,当屏障解除时,会调用该可调用对象,
        # + 如果,在调用该可调用对象时,出现了异常,该屏障会进入到broken状态(可以查看:例子)
        # timeout:是一个超时时间,它是wait()方法的默认超时时间

        self._cond = Condition(Lock())
        self._action = action
        self._timeout = timeout
        self._parties = parties
        # 关于屏障对象的状态,上面已经详细说明过
        self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken
        # 已经进入到屏障的线程的数量
        self._count = 0

    def wait(self, timeout=None):
        """等候屏障被解除或被打破。
        只有当指定数量的线程进入屏障时,屏障才会被移除,这些线程才能继续运行。
        如果指定了action,那么在屏障被移除之前,会先调用它。
        该方法返回,线程是第几个进入到屏障的(从0开始计数)
        """

        # 如果没有指定timeout参数,那么使用构造方法中指定的timeout
        if timeout is None:
            timeout = self._timeout

        with self._cond:
            # 当屏障正在解除,或者正在重置时,会阻塞尝试进入到屏障的线程
            # + 在屏障完成解除或重置时,会唤醒这些线程,它们会再次尝试进入屏障
            # + 在屏障被打破时,也会唤醒这些线程,它们会获得BrokenBarrierError,
            self._enter()

            index = self._count
            self._count += 1
            try:
                # 当已经有指定数量的线程,进入到屏障时,那么解除屏障:
                # + 1,运行action指定的可调用对象
                # + 2,将屏障置为解除状态
                # + 3,释放所有在屏障上等待的线程
                if index + 1 == self._parties:
                    self._release()
                # 否则,等待:
                # + 1,直到屏障被解除
                # + 2,直到屏障被重置
                # + 3,直到屏障被打破
                # + 4,达到超时时间,
                # + + 在达到超时时间之后,线程会在内部打破屏障
                else:
                    self._wait(timeout)

            # 在退出屏障时,
            # + 1,返回线程是第几个进入到屏障的(从0开始计数)
            # + 2,递减进入到屏障的线程数
            # + 3,最后一个退出屏障的线程会重置屏障的状态,
            # + + 并唤醒尝试进入屏障的线程
                return index
            finally:
                self._count -= 1
                self._exit()

    def _enter(self):
        while self._state in (-1, 1):
            self._cond.wait()
        if self._state < 0:
            raise BrokenBarrierError
        assert self._state == 0

    def _release(self):
        try:
            if self._action:
                self._action()
            self._state = 1
            self._cond.notify_all()
        except:
            self._break()
            raise

    def _wait(self, timeout):
        if not self._cond.wait_for(lambda : self._state != 0, timeout):
            self._break()
            raise BrokenBarrierError
        if self._state < 0:
            raise BrokenBarrierError
        assert self._state == 1

    def _exit(self):
        if self._count == 0:
            if self._state in (-1, 1):
                self._state = 0
                self._cond.notify_all()

    def reset(self):
        """把屏障重置到初始状态
        当前正在屏障上等待的线程会得到BrokenBarrierError,并退出屏障
        """

        # 1,如果屏障处于解除或重置状态,那么在最后一个等待的线程离开时,
        # + 它会将屏障重置为初始状态;因此,这两种状态无需任何处理

        # 2,否则,当屏障处于填充状态 或 打破状态时,
        # 2.1,如果有线程在等待,
        # + 那么,将屏障置为重置状态,一方面,它会阻塞尝试进入到屏障的线程;
        # + 另外一方面,当最后一个等待的线程离开时,它会将屏障置为初始状态
        # 2.2,如果没有线程在等待,
        # + 那么,该方法自己将屏障置为填充状态(也就是初始的状态),
        # + 然后,唤醒尝试进入到屏障的线程,让它们再次尝试进入屏障

        with self._cond:
            if self._count > 0:
                if self._state == 0:
                    self._state = -1
                elif self._state == -2:
                    self._state = -1
            else:
                self._state = 0
            self._cond.notify_all()

    def abort(self):
        """从外部打破屏障"""
        with self._cond:
            self._break()

    def _break(self):
        self._state = -2
        self._cond.notify_all()

    @property
    def parties(self):
        # 返回需要的线程数量
        return self._parties

    @property
    def n_waiting(self):
       # 返回当前正在屏障上等待的线程的数量
        if self._state == 0:
            return self._count
        return 0

    @property
    def broken(self):
        # 返回屏障是否处于broken状态
        return self._state == -2


class BrokenBarrierError(RuntimeError):
    pass