Source code for pyfr.backends.openmp.provider

# -*- coding: utf-8 -*-

from ctypes import Structure, byref, c_void_p

from pyfr.backends.base import (BaseKernelProvider,
                                BasePointwiseKernelProvider, Kernel)
from pyfr.backends.openmp.generator import OpenMPKernelGenerator
from pyfr.nputil import npdtype_to_ctypestype
from pyfr.util import memoize


class OpenMPKernelFunction:
    def __init__(self, fun, argcls):
        self.fun = fun
        self.kargs = argcls()

    def set_arg(self, i, v):
        setattr(self.kargs, f'arg{i}', getattr(v, '_as_parameter_', v))

    def set_args(self, *args, start=0):
        for i, arg in enumerate(args, start=start):
            self.set_arg(i, arg)

    def __call__(self):
        self.fun(byref(self.kargs))


class OpenMPKernelProvider(BaseKernelProvider):
    @memoize
    def _get_arg_cls(self, argtypes):
        fields = [(f'arg{i}', npdtype_to_ctypestype(arg))
                  for i, arg in enumerate(argtypes)]

        return type('ArgStruct', (Structure,), {'_fields_': fields})

    @memoize
    def _build_library(self, src):
        return self.backend.compiler.build(src)

    def _build_function(self, name, src, argtypes, restype=None):
        lib = self._build_library(src)
        return lib.function(name, restype,
                            [npdtype_to_ctypestype(arg) for arg in argtypes])

    def _build_kernel(self, name, src, argtypes):
        lib = self._build_library(src)
        fun = lib.function(name, None, [c_void_p])

        return OpenMPKernelFunction(fun, self._get_arg_cls(tuple(argtypes)))


[docs]class OpenMPPointwiseKernelProvider(OpenMPKernelProvider, BasePointwiseKernelProvider): kernel_generator_cls = OpenMPKernelGenerator
[docs] def _instantiate_kernel(self, dims, fun, arglst, argmv): rtargs = [] # Process the arguments for i, k in enumerate(arglst): if isinstance(k, str): rtargs.append((i, k)) else: fun.set_arg(i, k) class PointwiseKernel(Kernel): def run(self, queue, **kwargs): for i, k in rtargs: fun.set_arg(i, kwargs[k]) fun() return PointwiseKernel(*argmv)