commit 41f06a76cba7018c7aae32c934a3945f60b5e3cd
parent 5b5adef14a58975360f9a790aaa78ae0090a2fab
Author: Egor Achkasov <eaachkasov@edu.hse.ru>
Date: Wed, 4 Dec 2024 11:57:12 +0100
add pybind11 binding
Diffstat:
3 files changed, 73 insertions(+), 111 deletions(-)
diff --git a/compute_paths_pybind11.cpp b/compute_paths_pybind11.cpp
@@ -0,0 +1,73 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include "compute_paths.h"
+
+namespace py = pybind11;
+
+std::tuple< py::array_t<float>, py::array_t<float> >
+compute_paths_wrapper(
+ const std::string &mesh_filepath,
+ py::array_t<float> rx_positions,
+ py::array_t<float> tx_positions,
+ py::array_t<float> rx_velocities,
+ py::array_t<float> tx_velocities,
+ float carrier_frequency,
+ int num_rx,
+ int num_tx,
+ int num_paths,
+ int num_bounces
+) {
+ // Prepare input arrays (this is a basic implementation, check shapes and memory layout)
+ py::buffer_info rx_pos_info = rx_positions.request();
+ py::buffer_info tx_pos_info = tx_positions.request();
+ py::buffer_info rx_vel_info = rx_velocities.request();
+ py::buffer_info tx_vel_info = tx_velocities.request();
+
+ // Output arrays
+ // float *a_im = new float[num_paths];
+ // float *a_re = new float[num_paths];
+ float *a = new float[num_bounces * num_tx * num_paths * 3];
+ float *tau = new float[num_paths];
+
+ // Call the C function
+ compute_paths(
+ mesh_filepath.c_str(),
+ (float*)rx_pos_info.ptr, // Rx positions
+ (float*)tx_pos_info.ptr, // Tx positions
+ (float*)rx_vel_info.ptr, // Rx velocities
+ (float*)tx_vel_info.ptr, // Tx velocities
+ carrier_frequency,
+ (size_t)num_rx,
+ (size_t)num_tx,
+ (size_t)num_paths,
+ (size_t)num_bounces,
+ a,
+ tau
+ );
+
+ // Convert output arrays into numpy arrays for easy use in Python
+ //py::array_t<float> a_array = py::array_t<float>(num_paths, a);
+ py::array_t<float> a_array = py::array_t<float>({num_bounces, num_tx, num_paths, 3}, a);
+ py::array_t<float> tau_array = py::array_t<float>(num_paths, tau);
+
+ // Deallocate arrays
+ // delete[] a;
+ // delete[] tau;
+
+ // Return the results as a tuple (gains, delays)
+ return std::make_tuple(a_array, tau_array);
+}
+
+PYBIND11_MODULE(rt, m) {
+ m.def("compute_paths", &compute_paths_wrapper, "Compute gains and delays in PLY scene",
+ py::arg("mesh_filepath"),
+ py::arg("rx_positions"),
+ py::arg("tx_positions"),
+ py::arg("rx_velocities"),
+ py::arg("tx_velocities"),
+ py::arg("carrier_frequency"),
+ py::arg("num_rx"),
+ py::arg("num_tx"),
+ py::arg("num_paths"),
+ py::arg("num_bounces"));
+}
diff --git a/py_compute_paths.c b/py_compute_paths.c
@@ -1,94 +0,0 @@
-/* This file contains a python binding for compute_paths function. */
-
-#include <Python.h>
-#include <numpy/arrayobject.h> // Numpy C API
-#include "compute_paths.h" // Assuming this header contains the compute_paths function declaration
-
-// Wrapper function for Python to call compute_paths
-static PyObject* py_compute_paths(PyObject* self, PyObject* args) {
- const char* mesh_filepath;
- PyObject* rx_positions_obj;
- PyObject* tx_positions_obj;
- PyObject* rx_velocities_obj;
- PyObject* tx_velocities_obj;
- float carrier_frequency;
- int num_rx, num_tx, num_paths, num_bounces;
-
- if (!PyArg_ParseTuple(args, "sO&O&O&O&fiiii",
- &mesh_filepath,
- &rx_positions_obj,
- &tx_positions_obj,
- &rx_velocities_obj,
- &tx_velocities_obj,
- &carrier_frequency,
- &num_rx,
- &num_tx,
- &num_paths,
- &num_bounces))
- return NULL;
-
- if (!PyArray_Check(rx_positions_obj) || !PyArray_Check(tx_positions_obj) ||
- !PyArray_Check(rx_velocities_obj) || !PyArray_Check(tx_velocities_obj)) {
- PyErr_SetString(PyExc_TypeError, "All inputs must be numpy arrays.");
- return NULL;
- }
-
- npy_float32* rx_positions = (npy_float32*)PyArray_DATA((PyArrayObject*)rx_positions_obj);
- npy_float32* tx_positions = (npy_float32*)PyArray_DATA((PyArrayObject*)tx_positions_obj);
- npy_float32* rx_velocities = (npy_float32*)PyArray_DATA((PyArrayObject*)rx_velocities_obj);
- npy_float32* tx_velocities = (npy_float32*)PyArray_DATA((PyArrayObject*)tx_velocities_obj);
-
- // Output arrays for gains and delays
- float* a = (float*)malloc(num_paths * sizeof(float));
- float* tau = (float*)malloc(num_paths * sizeof(float));
-
- if (!a || !tau) {
- PyErr_NoMemory();
- return NULL;
- }
-
- // Call the C compute_paths function
- compute_paths(mesh_filepath,
- rx_positions, tx_positions,
- rx_velocities, tx_velocities,
- carrier_frequency,
- num_rx, num_tx, num_paths, num_bounces,
- a, tau);
-
- // Prepare the output tuple (gains, delays)
- //npy_intp dims[1] = {num_paths};
- npy_intp dims[] = {num_bounces, num_tx, num_paths};
- PyObject* gains_array = PyArray_SimpleNewFromData(1, dims, NPY_FLOAT32, a);
- PyObject* delays_array = PyArray_SimpleNewFromData(1, dims, NPY_INT32, tau);
-
- // Return the tuple of numpy arrays
- PyObject* result = PyTuple_Pack(2, gains_array, delays_array);
-
- // Clean up
- Py_DECREF(gains_array);
- Py_DECREF(delays_array);
-
- return result;
-}
-
-// Method table
-/* TODO add args and returns */
-static PyMethodDef module_methods[] = {
- {"compute_paths", py_compute_paths, METH_VARARGS, "Compute the paths between transmitters and receivers."},
- {NULL, NULL, 0, NULL}
-};
-
-// Module definition
-static struct PyModuleDef module_definition = {
- PyModuleDef_HEAD_INIT,
- "compute_paths_module", // Module name
- "C extension for computing paths.", // Module docstring
- -1,
- module_methods // Method table
-};
-
-// Module initialization function
-PyMODINIT_FUNC PyInit_compute_paths_module(void) {
- import_array(); // Initialize the numpy API
- return PyModule_Create(&module_definition);
-}
diff --git a/setup.py b/setup.py
@@ -1,17 +0,0 @@
-from setuptools import setup, Extension
-import numpy
-
-module = Extension(
- 'compute_paths_module',
- sources=['compute_paths.c', 'py_compute_paths.c'],
- include_dirs=[numpy.get_include()],
- libraries=["m", "xml2"],
- extra_compile_args=['-O3'],
-)
-
-setup(
- name='compute_paths_module',
- version='1.0',
- description='TODO write description.', # TODO write description
- ext_modules=[module],
-)