// compute_paths_pybind11.cpp - pybind11 bindings exposing compute_paths as the Python module `hermespy_rt`.
1 #include <pybind11/pybind11.h> 2 #include <pybind11/numpy.h> 3 4 #include <complex> 5 #include <vector> 6 #include <iostream> 7 8 #ifdef _WIN32 9 extern "C" { 10 #endif 11 #include "inc/compute_paths.h" 12 #include "inc/scene.h" 13 #include "inc/vec3.h" 14 #include "inc/ray.h" 15 #ifdef _WIN32 16 } 17 #endif 18 19 namespace py = pybind11; 20 21 // Helper function to create a complex numpy array from two C arrays 22 py::array_t<std::complex<float>> make_complex_array( 23 const float* real, 24 const float* imag, 25 size_t num_rx, 26 size_t num_tx, 27 size_t num_paths 28 ) { 29 size_t total_size = num_rx * num_tx * num_paths; 30 std::complex<float>* complex_data = new std::complex<float>[total_size]; 31 32 for (size_t i = 0; i < total_size; ++i) 33 complex_data[i] = std::complex<float>(real[i], imag[i]); 34 35 auto capsule = py::capsule(complex_data, [](void* ptr) { 36 delete[] static_cast<std::complex<float>*>(ptr); 37 }); 38 39 return py::array_t<std::complex<float>>( 40 {num_rx, num_tx, num_paths}, complex_data, capsule 41 ); 42 } 43 44 class ChannelInfoPython { 45 public: 46 size_t num_paths; 47 py::array_t<float> directions_rx; 48 py::array_t<float> directions_tx; 49 py::array_t<std::complex<float>> a_te; 50 py::array_t<std::complex<float>> a_tm; 51 py::array_t<float> tau; 52 py::array_t<float> freq_shift; 53 54 ChannelInfoPython(const ChannelInfo& chanInfo, size_t num_rx, size_t num_tx) { 55 num_paths = chanInfo.num_rays; 56 57 auto capsule_deleter = [](void* ptr) { delete[] static_cast<float*>(ptr); }; 58 59 directions_rx = py::array_t<float>( 60 std::vector<size_t>{num_rx, num_tx, (size_t)chanInfo.num_rays, 3}, 61 (float*)chanInfo.directions_rx, 62 py::capsule(chanInfo.directions_rx, capsule_deleter) 63 ); 64 directions_tx = py::array_t<float>( 65 std::vector<size_t>{num_rx, num_tx, (size_t)chanInfo.num_rays, 3}, 66 (float*)chanInfo.directions_tx, 67 py::capsule(chanInfo.directions_tx, capsule_deleter) 68 ); 69 70 a_te = make_complex_array( 71 
chanInfo.a_te_re, 72 chanInfo.a_te_im, 73 num_rx, 74 num_tx, 75 chanInfo.num_rays 76 ); 77 a_tm = make_complex_array( 78 chanInfo.a_tm_re, 79 chanInfo.a_tm_im, 80 num_rx, 81 num_tx, 82 chanInfo.num_rays 83 ); 84 85 tau = py::array_t<float>( 86 {num_rx, num_tx, (size_t)chanInfo.num_rays}, 87 chanInfo.tau, 88 py::capsule(chanInfo.tau, capsule_deleter) 89 ); 90 91 freq_shift = py::array_t<float>( 92 {num_rx, num_tx, (size_t)chanInfo.num_rays}, 93 chanInfo.freq_shift, 94 py::capsule(chanInfo.freq_shift, capsule_deleter) 95 ); 96 } 97 }; 98 99 std::tuple<ChannelInfoPython, ChannelInfoPython> 100 compute_paths_wrapper( 101 const std::string &mesh_filepath, 102 py::array_t<float> rx_positions, 103 py::array_t<float> tx_positions, 104 py::array_t<float> rx_velocities, 105 py::array_t<float> tx_velocities, 106 float carrier_frequency, 107 unsigned long int num_rx, 108 unsigned long int num_tx, 109 unsigned long int num_paths, 110 unsigned long int num_bounces 111 ) { 112 // Prepare input arrays (this is a basic implementation, check shapes and memory layout) 113 py::buffer_info rx_pos_info = rx_positions.request(); 114 py::buffer_info tx_pos_info = tx_positions.request(); 115 py::buffer_info rx_vel_info = rx_velocities.request(); 116 py::buffer_info tx_vel_info = tx_velocities.request(); 117 118 // Load the scene 119 Scene scene = scene_load(mesh_filepath.c_str()); 120 121 // Output channel information 122 ChannelInfo chanInfo_los = { 123 .num_rays = 1, 124 .directions_rx = new Vec3[num_rx * num_tx], 125 .directions_tx = new Vec3[num_rx * num_tx], 126 .a_te_re = new float[num_rx * num_tx], 127 .a_te_im = new float[num_rx * num_tx], 128 .a_tm_re = new float[num_rx * num_tx], 129 .a_tm_im = new float[num_rx * num_tx], 130 .tau = new float[num_rx * num_tx], 131 .freq_shift = new float[num_rx * num_tx] 132 }; 133 ChannelInfo chanInfo_scat = { 134 .num_rays = (uint32_t)(num_bounces * num_paths), 135 .directions_rx = new Vec3[num_rx * num_tx * num_bounces * num_paths], 136 
.directions_tx = new Vec3[num_rx * num_tx * num_bounces * num_paths], 137 .a_te_re = new float[num_rx * num_tx * num_bounces * num_paths], 138 .a_te_im = new float[num_rx * num_tx * num_bounces * num_paths], 139 .a_tm_re = new float[num_rx * num_tx * num_bounces * num_paths], 140 .a_tm_im = new float[num_rx * num_tx * num_bounces * num_paths], 141 .tau = new float[num_rx * num_tx * num_bounces * num_paths], 142 .freq_shift = new float[num_rx * num_tx * num_bounces * num_paths] 143 }; 144 // Output rays information 145 RaysInfo raysInfo_los = { 146 .rays = new Ray[num_rx * num_tx], 147 .rays_active = new uint8_t[num_rx * num_tx / 8 + 1] 148 }; 149 RaysInfo raysInfo_scat = { 150 .rays = new Ray[num_rx * num_tx * (num_bounces + 1) * num_paths], 151 .rays_active = new uint8_t[num_rx * num_tx * (num_bounces + 1) * num_paths / 8 + 1] 152 }; 153 154 // Call the C function 155 compute_paths( 156 &scene, // Scene 157 (Vec3*)rx_pos_info.ptr, // Rx positions 158 (Vec3*)tx_pos_info.ptr, // Tx positions 159 (Vec3*)rx_vel_info.ptr, // Rx velocities 160 (Vec3*)tx_vel_info.ptr, // Tx velocities 161 carrier_frequency, // Carrier frequency in GHz 162 (size_t)num_rx, 163 (size_t)num_tx, 164 (size_t)num_paths, 165 (size_t)num_bounces, 166 &chanInfo_los, // Channel info for LOS 167 &raysInfo_los, // Rays info for LOS 168 &chanInfo_scat, // Channel info for scatter 169 &raysInfo_scat // Rays info for scatter 170 ); 171 172 // Free the rays info (TODO remove when RaysInfo is made optional in compute_paths) 173 delete[] raysInfo_los.rays; 174 delete[] raysInfo_los.rays_active; 175 delete[] raysInfo_scat.rays; 176 delete[] raysInfo_scat.rays_active; 177 178 // Free the scene 179 free_scene(&scene); 180 181 // Wrap the results into Python objects 182 return std::make_tuple( 183 ChannelInfoPython(chanInfo_los, num_rx, num_tx), 184 ChannelInfoPython(chanInfo_scat, num_rx, num_tx) 185 ); 186 } 187 188 PYBIND11_MODULE(hermespy_rt, m) { 189 py::class_<ChannelInfoPython>(m, "ChannelInfo") 190 
.def_readonly("num_paths", &ChannelInfoPython::num_paths) 191 .def_readonly("directions_rx", &ChannelInfoPython::directions_rx) 192 .def_readonly("directions_tx", &ChannelInfoPython::directions_tx) 193 .def_readonly("a_te", &ChannelInfoPython::a_te) 194 .def_readonly("a_tm", &ChannelInfoPython::a_tm) 195 .def_readonly("tau", &ChannelInfoPython::tau) 196 .def_readonly("freq_shift", &ChannelInfoPython::freq_shift); 197 198 // TODO write a proper docstring 199 m.def("compute_paths", &compute_paths_wrapper, "Compute gains and delays", 200 py::arg("mesh_filepath"), 201 py::arg("rx_positions"), 202 py::arg("tx_positions"), 203 py::arg("rx_velocities"), 204 py::arg("tx_velocities"), 205 py::arg("carrier_frequency"), 206 py::arg("num_rx"), 207 py::arg("num_tx"), 208 py::arg("num_paths"), 209 py::arg("num_bounces")); 210 }