Python 元编程

什么是元编程?

元编程(Metaprogramming)是编写操作代码的代码的编程范式,简单说就是“编写能生成或修改代码的代码”。下面将介绍在 Python 中如何进行元编程。


装饰器

装饰器本身是可调用对象,其返回值也是可调用对象。装饰器分为不带参数的装饰器和带参数的装饰器。

不带参数的装饰器

不带参数的装饰器接受被装饰的可调用对象作为参数。

示例 1 - 打印执行时间:

from typing import Callable, Any
from functools import wraps
import time
import random
import logging

LOGGER = logging.getLogger(__name__)


def time_elapsed(f: Callable[..., Any]) -> Callable[..., Any]:
"""    打印被包装的可调用对象的执行时间    """@wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start: float = time.time()
        try:
            return f(*args, **kwargs)
        finally:
            end: float = time.time()
            LOGGER.debug(f"{end - start:.2f}s elapsed while invoking {f.__name__}(args={args}, kwargs={kwargs})")

    return wrapper


@time_elapsed
def test_time_elapsed() -> None:
"""测试函数"""time.sleep(random.randrange(1, 4))


# 等价于 test_time_elapsed = time_elapsed(test_time_elapsed)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_time_elapsed()

带参数的装饰器

带参数的装饰器返回的可调用对象接受被装饰的可调用对象为参数,返回可调用对象。

示例 2 - 在发生特定的异常时重试:

from typing import Callable, Any, Type
from functools import wraps
import time
import logging

LOGGER = logging.getLogger(__name__)


def retry(
        max_retries: int = 5,
        init_delay_ms: int = 20,
        max_delay_ms: int = 100,
        *exceptions: Type[BaseException]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""    在发生特定的异常时,进行重试。    :param max_retries: 最大重试次数    :param init_delay_ms: 初始等待时间,单位是毫秒    :param max_delay_ms: 最大等待时间,单位是毫秒    :param exceptions: 仅在发生指定的异常时重试    """def _inner(decorated: Callable[..., Any]) -> Callable[..., Any]:
"""        :param decorated: 被装饰的可调用对象        """@wraps(decorated)
        def _inner_most(*args: Any, **kwargs: Any) -> Any:
            retries: int = 0
            while True:
                try:
                    return decorated(*args, **kwargs)
                except exceptions:
                    if retries >= max_retries:
                        raise
                    delay: int = min(max_delay_ms, init_delay_ms * (2 ** retries))
                    retries += 1
                    LOGGER.info(f"retry count {retries}, delay {delay}")
                    time.sleep(delay / 1000.)

        return _inner_most

    return _inner


_cnt: int = 0


@retry(2, 20, 100, TypeError, ValueError)
def test_retry() -> None:
    global _cnt
    match _cnt:
        case 0:
            _cnt += 1
            raise TypeError("")
        case 1:
            _cnt += 1
            raise ValueError("")
        case _:
            print("succeeded")


# 等价于 test_retry = retry(2, 20, 100, TypeError, ValueError)(test_retry)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_retry()

描述符

描述符是实现 __get____set____delete__ 方法的类属性。

示例 3 - 只读属性:

from typing_extensions import Self, Any, Type


class ReadonlyProperty:
    def __init__(self: Self, value: Any) -> None:
        self.value = value

    def __get__(self: Self, instance: Any, owner: Type) -> Any:
        return self.value

    def __set__(self: Self, instance: Any, value: Any) -> None:
        raise RuntimeError("this is a readonly property")

    def __delete__(self: Self, instance: Any) -> None:
        raise RuntimeError("this is a readonly property")


class MyClass:
    attr = ReadonlyProperty("attr")


if __name__ == "__main__":
    obj = MyClass()
    # 触发 __get__
    print(obj.attr)
    # 触发 __set__
    obj.attr = "new value"
    # 触发 __del__
    del obj.attr

元类

创建类的两种方式

第一种是使用 class 关键字:

class AClass:
    def a_method(self):  # noqa
        print("a method")

第二种是使用 type 类,下面使用 type 创建一个与上面的类相同的类:

def a_method(self):
    print("a method")

BClass = type("BClass", (object,), {"a_method": a_method})

type 的三个参数分别是:类名、基类元组、属性字典。

什么是元类?

在 Python 中一切皆对象,也就是说类其实也是对象。类用于创建实例,元类用于创建类。上面讲的 type 是所有类的最终元类,所有类都是 type 的实例。

示例 4 - 自动为所有方法加装饰器:

from typing_extensions import Callable, Any, Dict, Tuple
import time
from functools import wraps
import types
import logging

LOGGER = logging.getLogger(__name__)


def time_elapsed(f: Callable[..., Any]) -> Callable[..., Any]:
"""    打印被包装的可调用对象的执行时间    """@wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start: float = time.time()
        try:
            return f(*args, **kwargs)
        finally:
            end: float = time.time()
            LOGGER.debug(f"{end - start:.2f}s elapsed while invoking {f.__name__}(args={args}, kwargs={kwargs})")

    return wrapper


class MetaClass(type):
    def __new__(cls, name: str, bases: Tuple, attrs: Dict[str, Any]) -> type:
        for attr_name, attr in attrs.items():  # type: str, Any
            if isinstance(attr, types.FunctionType):
                attrs[attr_name] = time_elapsed(attr)
        return super(MetaClass, cls).__new__(cls, name, bases, attrs)


class MyClass(metaclass=MetaClass):
    def method(self):  # noqa
        LOGGER.info("method in MyClass")


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    # 类是元类的对象
    print(isinstance(MyClass, MetaClass))  # True
    # 类的实例是类的对象
    print(isinstance(MyClass(), MyClass))  # True
    MyClass().method()

compile、exec、eval

eval

函数签名为:

def eval(__source: str | Buffer | CodeType,
         __globals: dict[str, Any] | None = None,
         __locals: Mapping[str, object] | None = None) -> Any

在给定的 globalslocals 上下文中计算给定的源码。

source 必须是表示 Python 表达式的字符串或 compile 返回的代码对象。globals 必须是字典, locals 可以是任意映射,默认为当前的全局命名空间和局部命名空间。如果仅指定 globals,那么 locals 默认与其相同。

比如:

a = eval("b+c", {"b": 1}, {"c": 2})

exec

函数签名为:

def exec(__source: str | Buffer | CodeType,
         __globals: dict[str, Any] | None = None,
         __locals: Mapping[str, object] | None = None) -> None

在给定的 globalslocals 上下文中执行给定的源码。

source 可以是表示一个或多个 Python 语句的字符串,或compile 返回的代码对象。globals 必须是字典, locals 可以是任意映射,默认为当前的全局命名空间和局部命名空间。如果仅指定 globals,那么 locals 默认与其相同。

比如:

exec(
    """
a = b + c
print(a)
    """,
    {"b": 1},
    {"c": 2},
)

compile

函数签名为:

def compile(source: str | Buffer | Module | Expression | Interactive,
            filename: str | Buffer | _PathLike,
            mode: str,
            flags: Literal[0],
            dont_inherit: bool = False,
            optimize: int = -1,
            *,
            _feature_version: int = -1) -> CodeType

将源代码编译为可被 exec()eval() 执行的代码对象。

源代码可以是 Python 模块、语句或表达式。文件名用于运行时错误信息。mode 必须是 exec,用于编译模块;single,用于编译单条(交互式)语句;或 eval,用于编译单个表达式。flags 参数(如果存在)控制哪些 __future__ 语句影响代码的编译。dont_inherit 参数如果为 True,将阻止编译继承调用代码中已生效的 __future__ 语句效果;如果未指定或为 False,这些语句将与显式指定的特性共同影响编译。

比如:

co = compile(
    """
import time

def current_timestamp() -> float:
    return time.time()

    """,
    "<no-file>",
    "exec",
)
exec(co)
print(current_timestamp())

ast

ast(Abstract Syntax Tree,抽象语法树)模块是用于处理语法树的核心工具。它可以将 Python 代码解析为树状结构,方便进行代码分析、转换和生成等操作。

比如:

import ast

code = """
def add(a, b):
    return a + b
"""

# 将代码解析为 AST
tree = ast.parse(code)


# 遍历 AST 节点示例
class FunctionVisitor(ast.NodeVisitor):
    def visit_FunctionDef(self, node):
        print(f"发现函数定义: {node.name}")
        self.generic_visit(node)


visitor = FunctionVisitor()
visitor.visit(tree)

示例 - ast + compile + exec

import ast

# 创建函数参数节点
args = ast.arguments(
    posonlyargs=[],
    args=[
        ast.arg(
            arg="a",
            annotation=ast.Name(id="int", ctx=ast.Load())  # 添加参数类型注解
        ),
        ast.arg(
            arg="b",
            annotation=ast.Name(id="int", ctx=ast.Load())  # 添加参数类型注解
        )
    ],
    defaults=[],
    kwonlyargs=[],
    kw_defaults=[]
)

# 创建函数体
return_stmt = ast.Return(
    value=ast.BinOp(
        left=ast.Name(id="a", ctx=ast.Load()),
        op=ast.Add(),
        right=ast.Name(id="b", ctx=ast.Load()),
    ),
)

# 构建函数定义节点
function_def = ast.FunctionDef(
    name="add",
    args=args,
    body=[return_stmt],
    decorator_list=[],
    lineno=0,  # 添加行号
    col_offset=0  # 添加列偏移
)

# 生成AST模块
module = ast.Module(body=[function_def], type_ignores=[])
ast.fix_missing_locations(module)

print("通过 ast 生成的代码如下:")
print(ast.unparse(module))

# 编译、执行 AST
code = compile(module, filename="<ast>", mode="exec", )
exec(code)

# 测试生成的函数
print(add(2, 3))  # 输出:5