hermespy-rt

Minimalistic signal processing ray-tracer in C
git clone https://git.ea.contact/hermespy-rt
Log | Files | Refs

commit 41f06a76cba7018c7aae32c934a3945f60b5e3cd
parent 5b5adef14a58975360f9a790aaa78ae0090a2fab
Author: Egor Achkasov <eaachkasov@edu.hse.ru>
Date:   Wed,  4 Dec 2024 11:57:12 +0100

add pybind11 binding

Diffstat:
Acompute_paths_pybind11.cpp | 73+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Dpy_compute_paths.c | 94-------------------------------------------------------------------------------
Dsetup.py | 17-----------------
3 files changed, 73 insertions(+), 111 deletions(-)

diff --git a/compute_paths_pybind11.cpp b/compute_paths_pybind11.cpp @@ -0,0 +1,73 @@ +#include <pybind11/pybind11.h> +#include <pybind11/numpy.h> +#include "compute_paths.h" + +namespace py = pybind11; + +std::tuple< py::array_t<float>, py::array_t<float> > +compute_paths_wrapper( + const std::string &mesh_filepath, + py::array_t<float> rx_positions, + py::array_t<float> tx_positions, + py::array_t<float> rx_velocities, + py::array_t<float> tx_velocities, + float carrier_frequency, + int num_rx, + int num_tx, + int num_paths, + int num_bounces +) { + // Prepare input arrays (this is a basic implementation, check shapes and memory layout) + py::buffer_info rx_pos_info = rx_positions.request(); + py::buffer_info tx_pos_info = tx_positions.request(); + py::buffer_info rx_vel_info = rx_velocities.request(); + py::buffer_info tx_vel_info = tx_velocities.request(); + + // Output arrays + // float *a_im = new float[num_paths]; + // float *a_re = new float[num_paths]; + float *a = new float[num_bounces * num_tx * num_paths * 3]; + float *tau = new float[num_paths]; + + // Call the C function + compute_paths( + mesh_filepath.c_str(), + (float*)rx_pos_info.ptr, // Rx positions + (float*)tx_pos_info.ptr, // Tx positions + (float*)rx_vel_info.ptr, // Rx velocities + (float*)tx_vel_info.ptr, // Tx velocities + carrier_frequency, + (size_t)num_rx, + (size_t)num_tx, + (size_t)num_paths, + (size_t)num_bounces, + a, + tau + ); + + // Convert output arrays into numpy arrays for easy use in Python + //py::array_t<float> a_array = py::array_t<float>(num_paths, a); + py::array_t<float> a_array = py::array_t<float>({num_bounces, num_tx, num_paths, 3}, a); + py::array_t<float> tau_array = py::array_t<float>(num_paths, tau); + + // Deallocate arrays + // delete[] a; + // delete[] tau; + + // Return the results as a tuple (gains, delays) + return std::make_tuple(a_array, tau_array); +} + +PYBIND11_MODULE(rt, m) { + m.def("compute_paths", &compute_paths_wrapper, "Compute gains and delays in PLY scene", + py::arg("mesh_filepath"), + py::arg("rx_positions"), + py::arg("tx_positions"), + py::arg("rx_velocities"), + py::arg("tx_velocities"), + py::arg("carrier_frequency"), + py::arg("num_rx"), + py::arg("num_tx"), + py::arg("num_paths"), + py::arg("num_bounces")); +} diff --git a/py_compute_paths.c b/py_compute_paths.c @@ -1,94 +0,0 @@ -/* This file contains a python binding for compute_paths function. */ - -#include <Python.h> -#include <numpy/arrayobject.h> // Numpy C API -#include "compute_paths.h" // Assuming this header contains the compute_paths function declaration - -// Wrapper function for Python to call compute_paths -static PyObject* py_compute_paths(PyObject* self, PyObject* args) { - const char* mesh_filepath; - PyObject* rx_positions_obj; - PyObject* tx_positions_obj; - PyObject* rx_velocities_obj; - PyObject* tx_velocities_obj; - float carrier_frequency; - int num_rx, num_tx, num_paths, num_bounces; - - if (!PyArg_ParseTuple(args, "sO&O&O&O&fiiii", - &mesh_filepath, - &rx_positions_obj, - &tx_positions_obj, - &rx_velocities_obj, - &tx_velocities_obj, - &carrier_frequency, - &num_rx, - &num_tx, - &num_paths, - &num_bounces)) - return NULL; - - if (!PyArray_Check(rx_positions_obj) || !PyArray_Check(tx_positions_obj) || - !PyArray_Check(rx_velocities_obj) || !PyArray_Check(tx_velocities_obj)) { - PyErr_SetString(PyExc_TypeError, "All inputs must be numpy arrays."); - return NULL; - } - - npy_float32* rx_positions = (npy_float32*)PyArray_DATA((PyArrayObject*)rx_positions_obj); - npy_float32* tx_positions = (npy_float32*)PyArray_DATA((PyArrayObject*)tx_positions_obj); - npy_float32* rx_velocities = (npy_float32*)PyArray_DATA((PyArrayObject*)rx_velocities_obj); - npy_float32* tx_velocities = (npy_float32*)PyArray_DATA((PyArrayObject*)tx_velocities_obj); - - // Output arrays for gains and delays - float* a = (float*)malloc(num_paths * sizeof(float)); - float* tau = (float*)malloc(num_paths * sizeof(float)); - - if (!a || !tau) { - PyErr_NoMemory(); - return NULL; - } - - // Call the C compute_paths function - compute_paths(mesh_filepath, - rx_positions, tx_positions, - rx_velocities, tx_velocities, - carrier_frequency, - num_rx, num_tx, num_paths, num_bounces, - a, tau); - - // Prepare the output tuple (gains, delays) - //npy_intp dims[1] = {num_paths}; - npy_intp dims[] = {num_bounces, num_tx, num_paths}; - PyObject* gains_array = PyArray_SimpleNewFromData(1, dims, NPY_FLOAT32, a); - PyObject* delays_array = PyArray_SimpleNewFromData(1, dims, NPY_INT32, tau); - - // Return the tuple of numpy arrays - PyObject* result = PyTuple_Pack(2, gains_array, delays_array); - - // Clean up - Py_DECREF(gains_array); - Py_DECREF(delays_array); - - return result; -} - -// Method table -/* TODO add args and returns */ -static PyMethodDef module_methods[] = { - {"compute_paths", py_compute_paths, METH_VARARGS, "Compute the paths between transmitters and receivers."}, - {NULL, NULL, 0, NULL} -}; - -// Module definition -static struct PyModuleDef module_definition = { - PyModuleDef_HEAD_INIT, - "compute_paths_module", // Module name - "C extension for computing paths.", // Module docstring - -1, - module_methods // Method table -}; - -// Module initialization function -PyMODINIT_FUNC PyInit_compute_paths_module(void) { - import_array(); // Initialize the numpy API - return PyModule_Create(&module_definition); -} diff --git a/setup.py b/setup.py @@ -1,17 +0,0 @@ -from setuptools import setup, Extension -import numpy - -module = Extension( - 'compute_paths_module', - sources=['compute_paths.c', 'py_compute_paths.c'], - include_dirs=[numpy.get_include()], - libraries=["m", "xml2"], - extra_compile_args=['-O3'], -) - -setup( - name='compute_paths_module', - version='1.0', - description='TODO write description.', # TODO write description - ext_modules=[module], -)