目录


例子

[root@iZj6chejzrsqpclb7miryaZ ~]# cat test.py 
# coding: utf8

from threading import current_thread, Thread, local, Lock

class MyThreadLocal(local):
    """Thread-local holder with a single `name` attribute (default "default")."""

    def __init__(self, name="default"):
        setattr(self, "name", name)

# Created once, in the main thread. The object itself is visible to every
# thread, but each thread reads and writes its own copy of its attributes.
lock = Lock()
my_local = MyThreadLocal()

def log():
    """Print the calling thread's name next to its view of my_local.name."""
    # Serialize output so lines printed by different threads do not interleave.
    with lock:
        # Single-argument print(...) behaves the same on Python 2 and 3;
        # Thread.name replaces the deprecated Thread.getName().
        print("[Thread: %s]my_local.name = %s" % (
            current_thread().name, my_local.name))

# Give the main thread's copy a value, then print it.
my_local.name = current_thread().name
log()

def func():
    """Thread target: set this thread's own copy of my_local.name and log it."""
    # Thread.name replaces the deprecated Thread.getName().
    my_local.name = current_thread().name
    log()

# Run func() in a second thread and wait for it to finish; the final log()
# shows the main thread's value was not changed by the other thread.
thread = Thread(target=func)
thread.start()
thread.join()
log()
[root@iZj6chejzrsqpclb7miryaZ ~]# python test.py 
[Thread: MainThread]my_local.name = MainThread
[Thread: Thread-1]my_local.name = Thread-1
[Thread: MainThread]my_local.name = MainThread

ThreadLocal变量

ThreadLocal变量,也就是线程本地变量,它会为每个使用该变量的线程维护一个变量的副本,在某个线程中,对该变量的修改,只会改变自己的副本,不会影响其他的线程的副本


Python中的实现

找到标准库threading的源代码:

[root@iZj6chejzrsqpclb7miryaZ ~]# python -c "import threading; print threading"
<module 'threading' from '/usr/lib64/python2.7/threading.pyc'>

在源代码中,搜索local,可以看到:

1204 # get thread-local implementation, either from the thread
1205 # module, or from the python fallback
1206 
1207 try:
1208     from thread import _local as local
1209 except ImportError:
1210     from _threading_local import local

也就是说,如果 C 实现的 thread 模块提供了 thread-local(即 `thread._local`),那么就使用它;否则退回到纯 Python 实现:_threading_local.local。
在 Python 2.7.5 中,thread 模块已经实现了 thread-local,所以不会用到纯 Python 的实现。但是为了学习,我们来看看纯 Python 实现的源代码。

[root@iZj6chejzrsqpclb7miryaZ ~]# python -c "import _threading_local as tl; print tl"
<module '_threading_local' from '/usr/lib64/python2.7/_threading_local.pyc'>

因为代码只有一百多行,所以全部贴了进来,并用注释对源代码进行解读。阅读之前,建议先看一下原文推荐的两篇文档(此处的链接已缺失);下面的注释会用到 `__slots__` 以及 `__new__`/`__init__` 调用顺序的相关知识。

_threading_local模块的源代码如下:

# Every thread-local object is an instance of `local` (or of a subclass of it),
# and `local` inherits from the `_localbase` class below.
class _localbase(object):
    # __slots__ restricts which attributes an instance may have. It only takes
    # effect on new-style classes, and _localbase is one (it inherits object).
    __slots__ = '_local__key', '_local__args', '_local__lock'

    # Constructor (runs before __init__).
    def __new__(cls, *args, **kw):
        # Create the thread-local object itself.
        self = object.__new__(cls)

        # The key is derived from the object's id(), so it is unique among
        # live objects; local.__del__ removes the per-thread entries stored
        # under it when the object is finalized.
        key = '_local__key', 'thread.local.' + str(id(self))
        object.__setattr__(self, '_local__key', key)

        # Remember the constructor arguments: when another thread first touches
        # this object, _patch re-runs __init__ with them to build that
        # thread's copy.
        object.__setattr__(self, '_local__args', (args, kw))

        # Re-entrant lock held by local's attribute hooks while they swap in a
        # thread's __dict__ and access it.
        object.__setattr__(self, '_local__lock', RLock())

        # Constructor arguments are only accepted when a subclass defines an
        # __init__ that can consume them.
        if (args or kw) and (cls.__init__ is object.__init__):
            raise TypeError("Initialization arguments are not supported")

        # We need to create the thread dict in anticipation of
        # __init__ being called, to make sure we don't call it
        # again ourselves.

        # Register this object's attribute dict as the creating thread's copy:
        # it is stored in the current thread object's __dict__ under `key`.
        # (Instances are of `local` or a subclass, which have a __dict__
        # because `local` declares no __slots__.)
        # Python calls __init__ automatically after __new__ returns, and the
        # attributes it sets land in this dict, so _patch later finds the copy
        # already present and does not call __init__ a second time.
        dict = object.__getattribute__(self, '__dict__')
        current_thread().__dict__[key] = dict

        return self
# Summary: creating a thread-local object attaches three attributes to it -- a
# key identifying it, the constructor arguments, and a lock for cross-thread
# synchronization -- and makes the creating thread's entry for `key` point at
# the object's attribute dict.

def _patch(self):
    """Install the calling thread's copy as this object's __dict__.

    Callers (local's attribute hooks) hold the object's _local__lock around
    this call and the attribute access that follows it.
    """
    # The key that identifies this thread-local object.
    key = object.__getattribute__(self, '_local__key')
    # This thread's copy (an attribute dict) stored under `key`, if any.
    d = current_thread().__dict__.get(key)

    # d is None: the current thread is touching this object for the first
    # time, so create a copy that belongs to it.
    if d is None:
        # 1. Make a fresh empty dict and point both the object's __dict__ and
        #    the current thread's entry for `key` at it.
        d = {}
        current_thread().__dict__[key] = d
        object.__setattr__(self, '__dict__', d)


        # 2. Call __init__ again with the original constructor arguments; it
        #    fills the new dict with this thread's initial attributes. Its
        #    assignments re-enter local.__setattr__ on this same thread while
        #    the lock is still held, which is why the lock is an RLock.
        # we have a new instance dict, so call out __init__ if we have
        # one
        cls = type(self)
        if cls.__init__ is not object.__init__:
            args, kw = object.__getattribute__(self, '_local__args')
            cls.__init__(self, *args, **kw)

        # The current thread's copy is now initialized.

    # d exists: make this thread's copy the object's __dict__.
    else:
        object.__setattr__(self, '__dict__', d)
# Summary: if the current thread has no copy yet, _patch creates one by
# re-running __init__. Either way it ends with the thread's copy installed as
# the object's __dict__. This is needed because every later operation (get,
# set or delete an attribute) is performed as an ordinary attribute operation
# on the object.

class local(_localbase):
    """Pure-Python thread-local object (fallback when thread._local is absent).

    Every attribute hook below follows the same pattern: take this object's
    re-entrant lock, call _patch to install the calling thread's copy as the
    object's __dict__, perform the ordinary object-level operation, then
    release the lock.
    """

    def __getattribute__(self, name):
        # Fetch the lock through object.__getattribute__ so this does not
        # recurse into the method being defined.
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__getattribute__(self, name)
        finally:
            lock.release()

    def __setattr__(self, name, value):
        # __dict__ is swapped per thread by _patch; direct assignment is refused.
        if name == '__dict__':
            raise AttributeError(
                "%r object attribute '__dict__' is read-only"
                % self.__class__.__name__)
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__setattr__(self, name, value)
        finally:
            lock.release()

    def __delattr__(self, name):
        # Deleting __dict__ is refused for the same reason as assigning it.
        if name == '__dict__':
            raise AttributeError(
                "%r object attribute '__dict__' is read-only"
                % self.__class__.__name__)
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__delattr__(self, name)
        finally:
            lock.release()

    # Finalizer: remove this object's copy from every thread that still holds
    # one.
    def __del__(self):
        import threading

        # The key that identifies this object.
        key = object.__getattribute__(self, '_local__key')

        # Snapshot of the threads the threading module currently knows about.
        try:
            # We use the non-locking API since we might already hold the lock
            # (__del__ can be called at any point by the cyclic GC).
            threads = threading._enumerate()
        except:
            # If enumerating the current threads fails, as it seems to do
            # during shutdown, we'll skip cleanup under the assumption
            # that there is nothing to clean up.
            return

        # For each thread, delete the entry stored under `key`, which releases
        # that thread's copy.
        for thread in threads:
            try:
                __dict__ = thread.__dict__
            except AttributeError:
                # Thread is dying, rest in peace.
                continue

            if key in __dict__:
                try:
                    del __dict__[key]
                except KeyError:
                    pass # didn't have anything in this thread

from threading import current_thread, RLock

一些思考

[root@iZj6chejzrsqpclb7miryaZ ~]# cat my_thread_local.py
# coding: utf8

from threading import current_thread
from weakref import proxy
import functools

# 因为dict不支持weakref,所以自定义一个map类型
class _MyMap(object):
    def __init__(self, map=None, **kw):
        self._map = {}
        if map:
            self._map.update(map)
        self._map.update(kw)

    def __setitem__(self, k, v):
        self._map[k] = v

    def __getitem__(self, k):
        if k not in self._map:
            raise AttributeError("no attribute: %s" % str(k))
        return self._map[k]

    def __delattr__(self, k):
        self._map.pop(k, None)

def _callback(my_map, map, key):
    print "[map=%s, key=%s]" % (map, key)
    map.pop(key, None)

def _patch(self):
    """Return the calling thread's attribute storage for the local ``self``.

    On a thread's first access, a fresh _MyMap is created, seeded with the
    constructor's keyword arguments. The thread object's __dict__ holds the
    strong reference to it. The local keeps only a weak proxy in its
    ident -> proxy map, and _callback drops that entry once the storage is
    collected.
    """
    thread = current_thread()
    ident = thread.ident
    key = object.__getattribute__(self, "_local__key")
    map = object.__getattribute__(self, "_local__map")

    # Thread idents may be recycled after a thread exits, so an entry in `map`
    # can belong to an earlier thread with the same ident whose Thread object
    # is still alive. Reuse the entry only when the *current* thread object
    # also holds storage for this local; otherwise start fresh. Testing
    # `is None` also avoids truth-testing the proxy, which would raise
    # ReferenceError if its referent were already gone.
    if map.get(ident) is None or key not in thread.__dict__:
        my_map = _MyMap(object.__getattribute__(self, "_local__kwargs"))
        map[ident] = proxy(my_map,
            functools.partial(_callback, map=map, key=ident))
        thread.__dict__[key] = my_map
    return map[ident]

class local(object):
    """Simplified thread-local object built on per-thread _MyMap storage.

    __init__ runs once, in the thread that constructs the object. Storage for
    every other thread starts out with only the constructor's keyword
    arguments; positional arguments are accepted but not copied.
    """
    # The base class keeps no attribute state of its own: values live in the
    # per-thread storage returned by _patch().
    __slots__ = "_local__key", "_local__kwargs", "_local__map"

    def __new__(cls, *a, **kw):
        self = object.__new__(cls)

        # Key under which each thread object stores its copy; id(self) keeps
        # it distinct among live locals of the same class.
        key = (cls.__name__, str(id(self)))
        object.__setattr__(self, "_local__key", key)
        # Keyword arguments seed each thread's storage (see _patch).
        object.__setattr__(self, "_local__kwargs", kw)
        # thread ident -> weak proxy to that thread's storage.
        object.__setattr__(self, "_local__map", {})

        return self

    def __setattr__(self, name, value):
        my_map = _patch(self)
        my_map[name] = value

    def __getattribute__(self, name):
        my_map = _patch(self)
        try:
            return my_map[name]
        except AttributeError:
            # Not a per-thread value: fall back to normal lookup so methods
            # and class attributes defined on subclasses stay reachable.
            return object.__getattribute__(self, name)

    def __delattr__(self, name):
        my_map = _patch(self)
        delattr(my_map, name)

[root@iZj6chejzrsqpclb7miryaZ ~]# cat test.py 
# coding: utf8

from threading import current_thread, Thread, Lock

from my_thread_local import local 

class MyThreadLocal(local):
    """Thread-local holder with a single `name` attribute (default "default")."""

    def __init__(self, name="default"):
        setattr(self, "name", name)

# Created once, in the main thread. The object itself is visible to every
# thread, but each thread reads and writes its own copy of its attributes.
lock = Lock()
my_local = MyThreadLocal()

def log():
    """Print the calling thread's name next to its view of my_local.name."""
    # Serialize output so lines printed by different threads do not interleave.
    with lock:
        # Single-argument print(...) behaves the same on Python 2 and 3;
        # Thread.name replaces the deprecated Thread.getName().
        print("[Thread: %s]my_local.name = %s" % (
            current_thread().name, my_local.name))

# Give the main thread's copy a value, then print it.
my_local.name = current_thread().name
log()

def func():
    """Thread target: set this thread's own copy of my_local.name and log it."""
    # Thread.name replaces the deprecated Thread.getName().
    my_local.name = current_thread().name
    log()

# Run func() in a second thread and wait for it to finish; the final log()
# shows the main thread's value was not changed by the other thread.
thread = Thread(target=func)
thread.start()
thread.join()
log()
[root@iZj6chejzrsqpclb7miryaZ ~]# python test.py 
[Thread: MainThread]my_local.name = MainThread
[Thread: Thread-1]my_local.name = Thread-1
[Thread: MainThread]my_local.name = MainThread
[map={139912237799232: <weakproxy at 0x7f3fdb26f890 to _MyMap at 0x7f3fdb264b90>, 139912099251968: <weakproxy at 0x7f3fdb26f940 to NoneType at 0x7f3fdb172f20>}, key=139912099251968]
[map={139912237799232: <weakproxy at 0x7f3fdb26f890 to NoneType at 0x7f3fdb172f20>}, key=139912237799232]