hermespy-rt

Minimalistic signal processing ray-tracer in C
git clone https://git.ea.contact/hermespy-rt
Log | Files | Refs

compute_paths_pybind11.cpp (7237B)


// pybind11 first: it pulls in Python.h, which must precede standard headers.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <complex>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#ifdef _WIN32
extern "C" {
#endif
    #include "inc/compute_paths.h"
    #include "inc/scene.h"
    #include "inc/vec3.h"
    #include "inc/ray.h"
#ifdef _WIN32
}
#endif
     18 
     19 namespace py = pybind11;
     20 
     21 // Helper function to create a complex numpy array from two C arrays
     22 py::array_t<std::complex<float>> make_complex_array(
     23     const float* real,
     24     const float* imag,
     25     size_t num_rx,
     26     size_t num_tx,
     27     size_t num_paths
     28 ) {
     29     size_t total_size = num_rx * num_tx * num_paths;
     30     std::complex<float>* complex_data = new std::complex<float>[total_size];
     31 
     32     for (size_t i = 0; i < total_size; ++i)
     33         complex_data[i] = std::complex<float>(real[i], imag[i]);
     34 
     35     auto capsule = py::capsule(complex_data, [](void* ptr) {
     36         delete[] static_cast<std::complex<float>*>(ptr);
     37     });
     38     
     39     return py::array_t<std::complex<float>>(
     40         {num_rx, num_tx, num_paths}, complex_data, capsule
     41     );
     42 }
     43 
     44 class ChannelInfoPython {
     45 public:
     46     size_t num_paths;
     47     py::array_t<float> directions_rx;
     48     py::array_t<float> directions_tx;
     49     py::array_t<std::complex<float>> a_te;
     50     py::array_t<std::complex<float>> a_tm;
     51     py::array_t<float> tau;
     52     py::array_t<float> freq_shift;
     53 
     54     ChannelInfoPython(const ChannelInfo& chanInfo, size_t num_rx, size_t num_tx) {
     55         num_paths = chanInfo.num_rays;
     56 
     57         auto capsule_deleter = [](void* ptr) { delete[] static_cast<float*>(ptr); };
     58 
     59         directions_rx = py::array_t<float>(
     60             std::vector<size_t>{num_rx, num_tx, (size_t)chanInfo.num_rays, 3},
     61             (float*)chanInfo.directions_rx,
     62             py::capsule(chanInfo.directions_rx, capsule_deleter)
     63         );
     64         directions_tx = py::array_t<float>(
     65             std::vector<size_t>{num_rx, num_tx, (size_t)chanInfo.num_rays, 3},
     66             (float*)chanInfo.directions_tx,
     67             py::capsule(chanInfo.directions_tx, capsule_deleter)
     68         );
     69 
     70         a_te = make_complex_array(
     71             chanInfo.a_te_re,
     72             chanInfo.a_te_im,
     73             num_rx,
     74             num_tx,
     75             chanInfo.num_rays
     76         );
     77         a_tm = make_complex_array(
     78             chanInfo.a_tm_re,
     79             chanInfo.a_tm_im,
     80             num_rx,
     81             num_tx,
     82             chanInfo.num_rays
     83         );
     84 
     85         tau = py::array_t<float>(
     86             {num_rx, num_tx, (size_t)chanInfo.num_rays},
     87             chanInfo.tau,
     88             py::capsule(chanInfo.tau, capsule_deleter)
     89         );
     90 
     91         freq_shift = py::array_t<float>(
     92             {num_rx, num_tx, (size_t)chanInfo.num_rays},
     93             chanInfo.freq_shift,
     94             py::capsule(chanInfo.freq_shift, capsule_deleter)
     95         );
     96     }
     97 };
     98 
     99 std::tuple<ChannelInfoPython, ChannelInfoPython>
    100 compute_paths_wrapper(
    101     const std::string &mesh_filepath,
    102     py::array_t<float> rx_positions,
    103     py::array_t<float> tx_positions,
    104     py::array_t<float> rx_velocities,
    105     py::array_t<float> tx_velocities,
    106     float carrier_frequency,
    107     unsigned long int num_rx,
    108     unsigned long int num_tx,
    109     unsigned long int num_paths,
    110     unsigned long int num_bounces
    111 ) {
    112     // Prepare input arrays (this is a basic implementation, check shapes and memory layout)
    113     py::buffer_info rx_pos_info = rx_positions.request();
    114     py::buffer_info tx_pos_info = tx_positions.request();
    115     py::buffer_info rx_vel_info = rx_velocities.request();
    116     py::buffer_info tx_vel_info = tx_velocities.request();
    117 
    118     // Load the scene
    119     Scene scene = scene_load(mesh_filepath.c_str());
    120 
    121     // Output channel information
    122     ChannelInfo chanInfo_los = {
    123         .num_rays = 1,
    124         .directions_rx = new Vec3[num_rx * num_tx],
    125         .directions_tx = new Vec3[num_rx * num_tx],
    126         .a_te_re = new float[num_rx * num_tx],
    127         .a_te_im = new float[num_rx * num_tx],
    128         .a_tm_re = new float[num_rx * num_tx],
    129         .a_tm_im = new float[num_rx * num_tx],
    130         .tau = new float[num_rx * num_tx],
    131         .freq_shift = new float[num_rx * num_tx]
    132     };
    133     ChannelInfo chanInfo_scat = {
    134         .num_rays = (uint32_t)(num_bounces * num_paths),
    135         .directions_rx = new Vec3[num_rx * num_tx * num_bounces * num_paths],
    136         .directions_tx = new Vec3[num_rx * num_tx * num_bounces * num_paths],
    137         .a_te_re = new float[num_rx * num_tx * num_bounces * num_paths],
    138         .a_te_im = new float[num_rx * num_tx * num_bounces * num_paths],
    139         .a_tm_re = new float[num_rx * num_tx * num_bounces * num_paths],
    140         .a_tm_im = new float[num_rx * num_tx * num_bounces * num_paths],
    141         .tau = new float[num_rx * num_tx * num_bounces * num_paths],
    142         .freq_shift = new float[num_rx * num_tx * num_bounces * num_paths]
    143     };
    144     // Output rays information
    145     RaysInfo raysInfo_los = {
    146         .rays = new Ray[num_rx * num_tx],
    147         .rays_active = new uint8_t[num_rx * num_tx / 8 + 1]
    148     };
    149     RaysInfo raysInfo_scat = {
    150         .rays = new Ray[num_rx * num_tx * (num_bounces + 1) * num_paths],
    151         .rays_active = new uint8_t[num_rx * num_tx * (num_bounces + 1) * num_paths / 8 + 1]
    152     };
    153 
    154     // Call the C function
    155     compute_paths(
    156         &scene,  // Scene
    157         (Vec3*)rx_pos_info.ptr,  // Rx positions
    158         (Vec3*)tx_pos_info.ptr,  // Tx positions
    159         (Vec3*)rx_vel_info.ptr,  // Rx velocities
    160         (Vec3*)tx_vel_info.ptr,  // Tx velocities
    161         carrier_frequency,  // Carrier frequency in GHz
    162         (size_t)num_rx,
    163         (size_t)num_tx,
    164         (size_t)num_paths,
    165         (size_t)num_bounces,
    166         &chanInfo_los,  // Channel info for LOS
    167         &raysInfo_los,  // Rays info for LOS
    168         &chanInfo_scat,  // Channel info for scatter
    169         &raysInfo_scat  // Rays info for scatter
    170     );
    171 
    172     // Free the rays info (TODO remove when RaysInfo is made optional in compute_paths)
    173     delete[] raysInfo_los.rays;
    174     delete[] raysInfo_los.rays_active;
    175     delete[] raysInfo_scat.rays;
    176     delete[] raysInfo_scat.rays_active;
    177 
    178     // Free the scene
    179     free_scene(&scene);
    180 
    181     // Wrap the results into Python objects
    182     return std::make_tuple(
    183         ChannelInfoPython(chanInfo_los, num_rx, num_tx),
    184         ChannelInfoPython(chanInfo_scat, num_rx, num_tx)
    185     );
    186 }
    187 
    188 PYBIND11_MODULE(hermespy_rt, m) {
    189     py::class_<ChannelInfoPython>(m, "ChannelInfo")
    190         .def_readonly("num_paths", &ChannelInfoPython::num_paths)
    191         .def_readonly("directions_rx", &ChannelInfoPython::directions_rx)
    192         .def_readonly("directions_tx", &ChannelInfoPython::directions_tx)
    193         .def_readonly("a_te", &ChannelInfoPython::a_te)
    194         .def_readonly("a_tm", &ChannelInfoPython::a_tm)
    195         .def_readonly("tau", &ChannelInfoPython::tau)
    196         .def_readonly("freq_shift", &ChannelInfoPython::freq_shift);
    197 
    198     // TODO write a proper docstring
    199     m.def("compute_paths", &compute_paths_wrapper, "Compute gains and delays",
    200           py::arg("mesh_filepath"),
    201           py::arg("rx_positions"),
    202           py::arg("tx_positions"),
    203           py::arg("rx_velocities"),
    204           py::arg("tx_velocities"),
    205           py::arg("carrier_frequency"),
    206           py::arg("num_rx"),
    207           py::arg("num_tx"),
    208           py::arg("num_paths"),
    209           py::arg("num_bounces"));
    210 }