# tf_paths.py (13281B)
# flake8: noqa: E501 # Ignore long lines

# Implementation following the
# "Learning Radio Environments by Differentiable Ray Tracing"
# 2024 NVIDIA paper
# http://dx.doi.org/10.1109/TMLCN.2024.3474639

# What this script does:
# - Loads the scene and measurements
# - Merges seconds and milliseconds (thus getting a total milliseconds dimension, further referred to as "samples")
# - Traces paths for each sample (picking g_cfg_num_samples samples at random) with sionna rt Candidatecalculator and ImageMethod, generating a dataset of g_cfg_num_samples PathBuffers
# - Trains the scene material parameters according to Algorithm 1 in the paper
# - Saves the losses and the trainable parameters (omega) to files (losses.npy and omegas.npy)

from measurements import load as measurements_load

import numpy as np
import tensorflow as tf
import sionna.rt as srt
from sionna.constants import PI

from typing import Tuple, List
import os

# Config
# Dataset
g_cfg_meas_round = 11  # Measurements round to load
g_cfg_num_samples = 10  # Number of samples to use from the dataset. Picked randomly.
g_cfg_batch_size = 3  # Number of samples in each batch (B)
# Fit
g_cfg_num_embeddings = 2  # L in the paper
g_cfg_learning_rate = 0.01  # beta in the paper
g_cfg_decay = tf.Variable(0.001, dtype=tf.float32)  # delta in the paper
g_cfg_max_iterations = 5  # max number of iterations for the optimization
g_cfg_convergence_threshold = 1e-6
g_cfg_eps = 1e-6  # For numerical gradient calculation
g_cfg_keep_material_params = False  # If True, start fitting from the current material parameters
# RF
g_cfg_central_frequency = 3.75e9  # Set into scene.frequency
g_cfg_subcarrier_spacing = 100e6  # Delta_f in the paper
g_cfg_sampling_rate = 1. / 1000.
# Spatial
g_cfg_rx_locations = [
    [4, 5, 2.2],
    [4, 5, 2.2],
    [20.5, 8, 3],
    [20.5, 8, 3],
    [22, 0.5, 2.85],
    [22, 0.5, 2.85],
]
g_cfg_rx_orientations = [
    [0, 0, -PI/2],
    [0, 0, 0],
    [0, 0, -PI/2],
    [0, 0, 0],
    [0, 0, PI/2],
    [0, 0, 0],
]
g_cfg_tx_position_initial = np.array([5.5, 5, 0.5])  # Initial TX position
g_cfg_tx_velocity = np.array([0.0005, 0, 0])  # meters per millisecond
# Sionna
g_cfg_max_depth = 3
g_cfg_trace_paths_num_samples = 1000  # Passed as `num_samples` to scene.trace_paths
g_cfg_synthetic_array = False  # NOTE(review): not applied to the scene anywhere in this file
g_cfg_los = True
g_cfg_specular_reflection = True
g_cfg_diffuse_reflection = True
g_cfg_refraction = True  # NOTE(review): forwarded as the `diffraction` flag below - confirm intent
# Output
g_cfg_record_loss = True
g_cfg_record_omegas = True

# Globals
g_scene: srt.Scene  # Assigned in the __main__ section
g_trainable_materials: List[srt.RadioMaterial] = []  # List of trainable materials in the scene
g_bandwidth: float  # Bandwidth of the signal (W in the paper)


@tf.function
def calc_material_parameters(omega: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Calculate the material parameters from the read-out vectors and embeddings.

    Each parameter is obtained from the dot product (over the embedding axis)
    of omega[0] and omega[1], mapped into its valid range.

    Args:
        omega: Read-out and embeddings tensor. Shape (2, 4, num_materials, num_embeddings)

    Returns:
        Tuple containing (each of shape (num_materials,)):
            - cond: Conductivity of the materials (exp, so > 0).
            - perm: Relative permittivity of the materials (exp + 1, so > 1).
            - scat: Scattering coefficient of the materials (sigmoid, in (0, 1)).
            - xpd: XPD coefficient of the materials (sigmoid, in (0, 1)).
    """
    dot = tf.reduce_sum(omega[0] * omega[1], axis=-1)
    cond = tf.exp(dot[0])
    perm = tf.exp(dot[1]) + 1.
    scat = tf.sigmoid(dot[2])
    xpd = tf.sigmoid(dot[3])
    return cond, perm, scat, xpd


@tf.function
def calc_omega(
    cond: tf.Tensor,
    perm: tf.Tensor,
    scat: tf.Tensor,
    xpd: tf.Tensor,
) -> tf.Tensor:
    """Calculate an omega tensor that reproduces the given material parameters.

    This is the inverse of `calc_material_parameters`: feeding the result back
    into it returns the inputs (up to the clipping described below). One factor
    (w) is drawn at random and the other (v) is solved for.

    Args:
        cond: Conductivity of the materials. Shape (num_materials,)
        perm: Relative permittivity of the materials. Shape (num_materials,)
        scat: Scattering coefficient of the materials. Shape (num_materials,)
        xpd: XPD coefficient of the materials. Shape (num_materials,)

    Returns:
        The omega tensor (shape: (2, 4, num_materials, num_embeddings)).
    """
    # Keep log / logit finite for boundary values such as a scattering
    # coefficient of exactly 0 or a relative permittivity of exactly 1.
    tiny = 1e-6
    cond = tf.maximum(tf.cast(cond, tf.float32), tiny)
    perm_excess = tf.maximum(tf.cast(perm, tf.float32) - 1., tiny)
    scat = tf.clip_by_value(tf.cast(scat, tf.float32), tiny, 1. - tiny)
    xpd = tf.clip_by_value(tf.cast(xpd, tf.float32), tiny, 1. - tiny)

    # Target dot products, shape (4, num_materials)
    dot = tf.stack(
        [
            tf.math.log(cond),
            tf.math.log(perm_excess),
            tf.math.log(scat) - tf.math.log(1. - scat),  # logit == -log(1 / scat - 1)
            tf.math.log(xpd) - tf.math.log(1. - xpd),
        ],
        axis=0,
    )

    # Random w with |w| in [0.1, 1) so the division below cannot blow up
    shape = (4, tf.shape(cond)[0], g_cfg_num_embeddings)
    magnitude = tf.random.uniform(shape, minval=0.1, maxval=1., dtype=tf.float32)
    sign = tf.where(tf.random.uniform(shape) < 0.5, -1., 1.)
    w = sign * magnitude

    # calc_material_parameters sums v * w over the embedding axis, so each of
    # the num_embeddings terms must contribute dot / num_embeddings.
    v = dot[:, :, tf.newaxis] / (w * float(g_cfg_num_embeddings))
    return tf.stack([v, w], axis=0)


def set_omega(omega: tf.Tensor) -> None:
    """Set the trainable material parameters in the scene from omega.

    Deliberately not a `tf.function`: the tensors are stored on the scene's
    material objects and must stay eager tensors recorded by the caller's
    `tf.GradientTape`, otherwise no gradient can reach omega.

    Args:
        omega: Read-out and embeddings tensor. Shape (2, 4, num_materials, num_embeddings)
    """
    cond, perm, scat, xpd = calc_material_parameters(omega)
    for idx, mat in enumerate(g_trainable_materials):
        scene_mat = g_scene.get(mat.name)
        scene_mat.conductivity = cond[idx]
        scene_mat.relative_permittivity = perm[idx]
        scene_mat.scattering_coefficient = scat[idx]
        scene_mat.xpd_coefficient = xpd[idx]


# NOTE(review): g_alpha is never reassigned in this file, so the moving average
# in `objective` always starts from this initial value - confirm against Algorithm 1.
g_alpha = tf.Variable(6e-9, dtype=tf.float32)  # Initial α_i


def objective(cirs_batch, batch_indices, omega, paths_buffers) -> tf.Tensor:
    """Compute the objective function for the optimization.

    Runs eagerly (no `tf.function` / `tf.py_function`) so that a surrounding
    `tf.GradientTape` records the whole chain omega -> material parameters ->
    field computation -> loss.

    Args:
        cirs_batch: Measured CIRs of the batch, complex.
            Shape (B, num_aps, num_rfs, num_subcarriers)
        batch_indices: Indices into `paths_buffers`, one per batch sample. Shape (B,)
        omega: Read-out and embeddings tensor. Shape (2, 4, num_materials, num_embeddings)
        paths_buffers: List of traced paths (the outputs of scene.trace_paths),
            indexed by `batch_indices`.

    Returns:
        The average loss over the batch as in (38) of the paper (scalar tensor).
    """
    set_omega(omega)

    # Predicted total power and mean delay for every sample of the batch
    p_hat_per_sample = []
    tau_hat_per_sample = []
    for sample_idx in batch_indices:
        a, tau = g_scene.compute_fields(
            *paths_buffers[int(sample_idx)],
            check_scene=False,
            scat_random_phases=False,
            testing=False,
        ).cir(
            los=g_cfg_los,
            reflection=g_cfg_specular_reflection,
            diffraction=g_cfg_refraction,
            scattering=g_cfg_diffuse_reflection,
            ris=False,
            cluster_ris_paths=False,
            num_paths=None,
        )
        p_hat_per_sample.append(tf.reduce_sum(tf.abs(a) ** 2))
        # NOTE(review): plain mean of the path delays, not an RMS delay spread - confirm
        tau_hat_per_sample.append(tf.reduce_mean(tau))
    p_hat = tf.stack(p_hat_per_sample)  # (B,)
    tau_hat = tf.stack(tau_hat_per_sample)  # (B,)

    # Scale the batch: h_b ← √α_i * h_b
    # Power per subcarrier/tap, summed over batch, APs and RF chains: (num_subcarriers,)
    p = tf.reduce_sum(tf.abs(cirs_batch) ** 2, axis=(0, 1, 2))
    # NOTE(review): sums the outer product of p (per tap) and p_hat (per sample) - confirm against the paper
    alpha_hat = tf.reduce_sum(p[..., tf.newaxis] * p_hat) / (tf.reduce_sum(p_hat ** 2) + 1e-8)
    alpha = g_cfg_decay * g_alpha + (1 - g_cfg_decay) * alpha_hat
    alpha = tf.cast(alpha, tf.complex64)
    cirs_scaled = tf.sqrt(alpha) * cirs_batch
    p = tf.reduce_sum(tf.abs(cirs_scaled) ** 2, axis=(0, 1, 2))

    # Approximate tau_RMS as in (40, 41)
    total_channel_gain = tf.reduce_sum(p)
    p = p / (total_channel_gain + 1e-8)  # Normalized power profile, sums to ~1
    num_taps = int(cirs_batch.shape[-1])
    ls = tf.range(num_taps, dtype=tf.float32) - float(num_taps // 2)  # Centered tap indices
    # Power-weighted mean index and RMS spread. Both are element-wise over the
    # tap axis; broadcasting ls against p as (S, 1) * (S,) would build an (S, S)
    # matrix and lose the weighting.
    tau_overline = tf.reduce_sum(ls * p)
    tau_rms = tf.sqrt(tf.reduce_sum(((ls - tau_overline) / g_bandwidth) ** 2 * p))

    # Compute sample loss L_b as in (38)
    def smape(x, y):
        return tf.abs(x - y) / (tf.abs(x) + tf.abs(y) + 1e-8)

    # NOTE(review): compares the normalized per-tap power (S, 1) with the total
    # predicted power per sample (B,) - kept as in the original, confirm against (38)
    losses = tf.reduce_mean(smape(p[..., tf.newaxis], p_hat), axis=0) + smape(tau_rms, tau_hat)
    # Average loss over the batch: L = (1/B) * sum(L_b)
    return tf.reduce_mean(losses)


if __name__ == "__main__":
    # Initialization
    # Scene
    print("Loading scene... ", end='', flush=True)
    g_scene = srt.load_scene(
        os.path.join(
            os.path.dirname(__file__),
            "..",
            "scene",
            "scene.xml")
    )
    g_scene.frequency = g_cfg_central_frequency
    g_scene.tx_array = srt.PlanarArray(
        num_rows=1,
        num_cols=1,
        vertical_spacing=0.5,
        horizontal_spacing=0.5,
        pattern="dipole",
        polarization="V",
    )
    g_scene.rx_array = srt.PlanarArray(
        num_rows=1,
        num_cols=1,
        vertical_spacing=0.5,
        horizontal_spacing=0.5,
        pattern="dipole",
        polarization="VH",
    )
    for rx_idx, rx_loc in enumerate(g_cfg_rx_locations):
        rx = srt.Receiver(
            name=f"rx{rx_idx}",
            position=rx_loc,
            orientation=g_cfg_rx_orientations[rx_idx],
        )
        g_scene.add(rx)
    tx = srt.Transmitter(
        name="tx",
        position=g_cfg_tx_position_initial,
    )
    g_scene.add(tx)
    print("Done")

    # Make trainable materials
    # Iterate over a snapshot because materials are added to the scene inside the loop
    for mat in list(g_scene.radio_materials.values()):
        if not mat.is_used:
            continue
        # Create new trainable material
        if g_cfg_keep_material_params:
            # Start from the current material parameters
            new_mat = srt.RadioMaterial(
                mat.name + "_train",
                relative_permittivity=mat.relative_permittivity,
                conductivity=mat.conductivity,
                scattering_coefficient=mat.scattering_coefficient,
                xpd_coefficient=mat.xpd_coefficient
            )
        else:
            new_mat = srt.RadioMaterial(mat.name + "_train",
                                        relative_permittivity=3.0,
                                        conductivity=0.1)
        g_scene.add(new_mat)
        g_trainable_materials.append(new_mat)

    # Assign trainable materials to the corresponding objects
    for obj in g_scene.objects.values():
        obj.radio_material = obj.radio_material.name + "_train"

    # Load the measurements
    print("Loading measurements... ", end='', flush=True)
    data_dirpath = os.path.join(
        os.path.dirname(__file__),
        "..", "..", "data"
    )
    # [num_aps, num_rfs, num_seconds, num_subcarriers, 1000]
    cirs = measurements_load(data_dirpath, g_cfg_meas_round)
    # Cast to complex64
    cirs = cirs.astype(np.complex64)
    # Move num_seconds to 0 and 1000 to 1
    # [num_seconds, 1000, num_aps, num_rfs, num_subcarriers]
    cirs = cirs.transpose((2, 4, 0, 1, 3))
    # Unite num_seconds and 1000
    # [num_seconds * 1000, num_aps, num_rfs, num_subcarriers]
    cirs = cirs.reshape((cirs.shape[0] * cirs.shape[1], cirs.shape[2], cirs.shape[3], cirs.shape[4]))
    # W in the paper. The subcarrier axis is the LAST one after the reshape
    # (axis -2 is num_rfs).
    g_bandwidth = cirs.shape[-1] * g_cfg_subcarrier_spacing
    print("Done")

    # Make the dataset
    # Contains:
    # - CIRs (shape (num_samples, num_aps, num_rfs, num_subcarriers))
    # - paths indices (shape (num_samples,))
    indices = np.random.randint(0, cirs.shape[0], (g_cfg_num_samples,))
    dataset = tf.data.Dataset.from_tensor_slices((cirs[indices], np.arange(g_cfg_num_samples)))
    assert dataset.cardinality() == g_cfg_num_samples

    # Trace paths
    print("Tracing paths... ", end="", flush=True)
    paths_buffers = []  # One traced-paths entry per picked sample, same order as `indices`
    # TX position at each picked sample (sample index == elapsed milliseconds)
    tx_poss = g_cfg_tx_position_initial + indices[:, np.newaxis] * g_cfg_tx_velocity
    for tx_pos in tx_poss:
        tx.position = tx_pos
        paths_buffers.append(g_scene.trace_paths(
            max_depth=g_cfg_max_depth,
            num_samples=g_cfg_trace_paths_num_samples,
            los=g_cfg_los,
            reflection=g_cfg_specular_reflection,
            diffraction=g_cfg_refraction,
            scattering=g_cfg_diffuse_reflection,
            ris=False,
            edge_diffraction=False,
            check_scene=False,
        ))
    print("Done")

    # Fit the scene
    if g_cfg_keep_material_params:
        # Seed omega from the materials' current parameters; a random omega
        # would overwrite them on the first set_omega() call.
        omega_init = calc_omega(
            tf.stack([tf.cast(m.conductivity, tf.float32) for m in g_trainable_materials]),
            tf.stack([tf.cast(m.relative_permittivity, tf.float32) for m in g_trainable_materials]),
            tf.stack([tf.cast(m.scattering_coefficient, tf.float32) for m in g_trainable_materials]),
            tf.stack([tf.cast(m.xpd_coefficient, tf.float32) for m in g_trainable_materials]),
        )
    else:
        omega_init = tf.random.uniform(
            (2, 4, len(g_trainable_materials), g_cfg_num_embeddings),
            minval=-1, maxval=1, dtype=tf.float32
        )
    omega = tf.Variable(omega_init, dtype=tf.float32)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=g_cfg_learning_rate,
        # beta_1=0.9,
        # beta_2=0.999,
        # epsilon=1e-8,
    )

    print("Starting optimization...", flush=True)
    loss_history = []
    omega_history = []
    prev_loss = float("inf")
    # Cycle through the batches until g_cfg_max_iterations updates were made
    batches = dataset.batch(g_cfg_batch_size).repeat().take(g_cfg_max_iterations)
    for cirs_batch, paths_indices_batch in batches:
        with tf.GradientTape() as tape:
            curr_loss = objective(cirs_batch, paths_indices_batch, omega, paths_buffers)
        grad = tape.gradient(curr_loss, omega)
        if grad is None:
            raise RuntimeError("The loss does not depend on omega; no gradient available")
        optimizer.apply_gradients(zip([grad], [omega]))

        loss_value = float(curr_loss)
        if g_cfg_record_loss:
            loss_history.append(loss_value)
        if g_cfg_record_omegas:
            omega_history.append(omega.numpy().copy())

        # Check convergence: Stop if change in loss is below threshold
        if abs(prev_loss - loss_value) < g_cfg_convergence_threshold:
            print("Converged")
            break
        prev_loss = loss_value
        print(f"Loss: {curr_loss}", flush=True)

    print("Training completed")

    # Save the losses and omegas
    losses = np.asarray(loss_history, dtype=np.float32)
    omegas = np.asarray(omega_history)
    if g_cfg_record_loss:
        np.save("losses.npy", losses)
        print("Saved losses to losses.npy: ", losses.shape)
    if g_cfg_record_omegas:
        omegas = omegas.reshape(
            (-1, 2, 4, len(g_trainable_materials), g_cfg_num_embeddings)
        )
        np.save("omegas.npy", omegas)
        print("Saved omegas to omegas.npy: ", omegas.shape)