
A Simpler Way to Call C Library Functions from Python


Introduction

Sometimes we need to call functions written in C from Python. A common way to do it looks like this:

from ctypes import *

dll = CDLL('my_dll.dll')
add = dll.add
add.argtypes = [c_int, c_int] # argument types
add.restype = c_int # return type
print(add(1, 2)) # 3
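`my_dll.dll` above stands for a DLL of your own. To try the same argtypes/restype pattern without building one, here is a minimal sketch that assumes a POSIX system (Linux or macOS), where `ctypes.CDLL(None)` exposes the symbols already loaded into the Python process, including the C library's `strlen`:

```python
import ctypes

# Assumption: POSIX system. CDLL(None) gives access to symbols already
# loaded into the process, which includes the C standard library.
libc = ctypes.CDLL(None)

strlen = libc.strlen
strlen.argtypes = [ctypes.c_char_p]  # const char *s
strlen.restype = ctypes.c_size_t     # size_t return value

print(strlen(b'hello'))  # 5
```

Without the `restype` line ctypes would assume a C `int` return value, which is why declaring it explicitly matters for anything other than `int`.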

That is quite concise, but personally I still don't find it elegant enough. My favorite is the C# style:

C#
using System;
using System.Runtime.InteropServices;

class Example
{
    // Use DllImport to import the Win32 MessageBox function.
    [DllImport("user32.dll", CharSet = CharSet.Unicode)]
    public static extern int MessageBox(IntPtr hWnd, String text, String caption, uint type);

    static void Main()
    {
        // Call the MessageBox function using platform invoke.
        MessageBox(new IntPtr(0), "Hello World!", "Hello Dialog", 0);
    }
}

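I put together a simple implementation of a similar style in Python. Before explaining how it works, let's look at what it looks like in practice~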

Demo

Take printing the Fibonacci sequence as an example:

stdio
from extern_c import *
from ctypes import *

@extern_func('msvcrt')
def printf(fmt: c_char_p, *args) -> c_int:
    pass

@extern_func('msvcrt')
def scanf(restrict_format: c_char_p, *args) -> c_int:
    pass
Fibonacci Sequence
from stdio import *  # the stdio module shown above
from ctypes import *

a = c_int(0)
b = c_int(1)
m = c_int(0)

printf(b'Input a positive integer: ')
# Compare m.value: ctypes integers do not support <, <= etc. against int
if scanf(b'%d', byref(m)) != 1 or m.value <= 0:
    printf(b'Invalid input!')
    exit(-1)

while a.value <= m.value:
    printf(b'Fibonacci sequence: %d' if a.value == 0 else b', %d', a)
    a, b = b, c_int(a.value + b.value)

The console input/output looks like this:

Console
load extern function: printf, callback: None
load extern function: scanf, callback: None
Input a positive integer: 10000
Fibonacci sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765
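The printed sequence can be cross-checked without msvcrt. Below is the same loop in plain Python; `fibonacci_upto` is a hypothetical helper, not part of the original code. It keeps two consecutive terms and stops as soon as the current term exceeds the input:

```python
def fibonacci_upto(m: int) -> list[int]:
    """Return every Fibonacci number a with a <= m, mirroring the ctypes loop."""
    terms = []
    a, b = 0, 1
    while a <= m:
        terms.append(a)
        a, b = b, a + b
    return terms

print(', '.join(map(str, fibonacci_upto(10000))))
# 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765
```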

Pretty elegant, right~ This style takes full advantage of Python's type hints and is hardly any different from defining an ordinary function.

Currently, extern_func supports the following parameter kinds:

  • positional-only
  • positional or keyword
  • variadic positional
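These names are the `kind` descriptions that the `inspect` module reports for each parameter. The sketch below lists the kinds found in a signature and marks which ones fall outside the supported set (keyword-only and `**kwargs`); the function `demo` exists only for illustration:

```python
from inspect import Parameter, signature

def demo(a, /, b, *args, c, **kwargs):
    pass

# The three kinds from the list above
SUPPORTED = {
    Parameter.POSITIONAL_ONLY,
    Parameter.POSITIONAL_OR_KEYWORD,
    Parameter.VAR_POSITIONAL,
}

for param in signature(demo).parameters.values():
    status = 'supported' if param.kind in SUPPORTED else 'rejected'
    print(f'{param.name}: {param.kind.description} -> {status}')
# a: positional-only -> supported
# b: positional or keyword -> supported
# args: variadic positional -> supported
# c: keyword-only -> rejected
# kwargs: variadic keyword -> rejected
```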

It also supports one more form:

@extern_func('msvcrt')
def scanf(restrict_format: c_char_p, *args) -> c_int:
    def callback(v):
        if v <= 0:
            raise EOFError()
        return v
    return callback

The callback function inside scanf is called after every call to scanf. Its parameter v is the return value of the C scanf, and whatever callback returns is what the caller ultimately receives.
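Put differently, the decorated body is not the implementation: it runs once when the decorator is applied, and any callable it returns post-processes every C return value. The sketch below reproduces only that control flow in pure Python so it runs anywhere; `with_callback` is a hypothetical stand-in for `extern_func`, and the plain Python function `fake_scanf` plays the role of the C function:

```python
from functools import wraps
from inspect import signature, Parameter

def with_callback(impl):
    """Hypothetical stand-in for extern_func; impl plays the C function."""
    def decorator(func):
        # Count the fixed (non-variadic) parameters, as extern_func does
        fixed = [p for p in signature(func).parameters.values()
                 if p.kind != Parameter.VAR_POSITIONAL]
        # Run the decorated body once, with placeholder arguments
        callback = func(*range(len(fixed)))

        @wraps(func)
        def wrapper(*args):
            value = impl(*args)
            return callback(value) if callable(callback) else value
        return wrapper
    return decorator

def fake_scanf(fmt, *args):
    # Pretend no input item matched, as C scanf reports with 0
    return 0

@with_callback(fake_scanf)
def scanf(restrict_format, *args):
    def callback(v):
        if v <= 0:
            raise EOFError()
        return v
    return callback

try:
    scanf(b'%d')
except EOFError:
    print('callback raised EOFError')
```

When the body just says `pass`, it returns None, so the wrapper hands back the raw return value unchanged.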

Implementation

Prerequisites

  1. Python decorators
  2. Basic usage of ctypes
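`extern_func` is a decorator that takes arguments, so it has three layers: the outer call receives the decorator's own settings, the middle function receives the decorated function, and the innermost wrapper runs on every call. A minimal generic sketch of that shape (the names `tagged` and `greet` are purely illustrative):

```python
from functools import wraps

def tagged(tag: str):
    # Layer 1: receives the decorator's own arguments
    def decorator(func):
        # Layer 2: runs once, when the function is decorated
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Layer 3: runs on every call
            return f'[{tag}] {func(*args, **kwargs)}'
        return wrapper
    return decorator

@tagged('demo')
def greet(name: str) -> str:
    return f'hello, {name}'

print(greet('ctypes'))  # [demo] hello, ctypes
print(greet.__name__)   # greet (kept by functools.wraps)
```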

Code

from functools import wraps
from inspect import signature, Signature, Parameter
from typing import Literal
import ctypes

_DLL_TYPES = {
    '__cdecl': ctypes.CDLL,
    '__stdcall': ctypes.WinDLL  # ctypes.WinDLL only exists on Windows
}

# Loaded libraries, keyed by (name, calling convention) so the same DLL
# loaded with a different convention gets its own loader object
_dll_cache = {}

def _import_dll(name, calling_convention):
    key = (name, calling_convention)
    try:
        return _dll_cache[key]
    except KeyError:
        dll = _DLL_TYPES[calling_convention](name)
        _dll_cache[key] = dll
        return dll

def _get_parameters_info(parameters):
    # Allowed kinds: positional-only, positional or keyword, variadic positional
    argtypes = []

    for param in parameters:
        if param.kind in (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD):
            raise SyntaxError(f'parameter {param.name} must not be a {param.kind.description} parameter')

        if param.kind != Parameter.VAR_POSITIONAL:
            if param.annotation == Parameter.empty:
                raise SyntaxError(f'parameter {param.name} must have an annotation')
            # The annotation is the ctypes type of this fixed parameter
            argtypes.append(param.annotation)
    return argtypes

def extern_func(
    dll_name: str,
    entry_point: str = ...,
    calling_convention: Literal['__cdecl', '__stdcall'] = '__cdecl'
):
    def func_decorator(func):
        sig = signature(func)
        argtypes = _get_parameters_info(sig.parameters.values())
        # Run the Python body once with placeholder arguments; if it returns
        # a callable, that callable post-processes every C return value
        callback = func(*range(len(argtypes)))

        dll = _import_dll(dll_name, calling_convention)
        func_name = entry_point if entry_point != Ellipsis else func.__name__
        # dll[name] returns a fresh function pointer each time (attribute
        # access is cached), so argtypes/restype set here stay local to this declaration
        func_ptr = dll[func_name]
        func_ptr.argtypes = argtypes
        func_ptr.restype = None if sig.return_annotation == Signature.empty else sig.return_annotation

        print(f'load extern function: {func_name}', f'callback: {callback}', sep=', ')

        @wraps(func)
        def func_wrap(*args, **kwargs):
            # Match positional/keyword arguments to the declared parameters,
            # fill in defaults, then pass everything positionally in order
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            value = func_ptr(*bound.args)
            return callback(value) if callable(callback) else value
        return func_wrap
    return func_decorator
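The trickiest part of the wrapper is turning whatever the caller passed (positional arguments, keyword arguments, omitted defaults, extra variadic values) into one flat argument list in declaration order, since the C function is called with positional arguments only. The standard library already implements Python's own matching rules in `inspect.Signature.bind`; the sketch below shows its behaviour on a hypothetical declaration `f` that has a default and `*args` (`flatten` is likewise an illustrative name):

```python
from inspect import signature

# Hypothetical declaration: one parameter with a default, then varargs
def f(fmt, width=8, *args):
    pass

sig = signature(f)

def flatten(*args, **kwargs):
    """Map a call onto f's parameters and return the positional tuple."""
    bound = sig.bind(*args, **kwargs)  # raises TypeError on an invalid call
    bound.apply_defaults()             # fills in width=8 and args=()
    return bound.args

print(flatten(b'%d'))               # (b'%d', 8)
print(flatten(b'%d', 3, 1, 2))      # (b'%d', 3, 1, 2)
print(flatten(width=3, fmt=b'%d'))  # (b'%d', 3)

try:
    flatten(b'%d', fmt=b'%s')       # fmt given twice
except TypeError as e:
    print('TypeError:', e)
```

Delegating to `bind` means keyword arguments, defaults and bad calls behave exactly as they would for an ordinary Python function.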
