anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

tf_paths.py (13281B)


      1 # flake8: noqa: E501  # Ignore long lines
      2 
      3 # Implementation following the
      4 # "Learning Radio Environments by Differentiable Ray Tracing"
      5 # 2024 NVIDIA paper
      6 # http://dx.doi.org/10.1109/TMLCN.2024.3474639
      7 
      8 # What this script does:
      9 # - Loads the scene and measurements
     10 # - Merges seconds and milliseconds (thus getting a total milliseconds dimension, further referred to as "samples")
     11 # - Traces paths for each sample (picking random g_num_samples samples) with sionna rt Candidatecalculator and ImageMethod, generating a dataset of g_num_samples PathBuffers
     12 # - Trains the scene material parameters according to Algorithm 1 in the paper
     13 # - Saves the losses and the trainable parameters (omega) to files (losses.npy and omegas.npy)
     14 
     15 from measurements import load as measurements_load
     16 
     17 import numpy as np
     18 import tensorflow as tf
     19 import sionna.rt as srt
     20 from sionna.constants import PI
     21 
     22 from typing import Tuple, List
     23 import os
     24 
# Config
# Dataset
g_cfg_meas_round = 11  # Measurements round to load
g_cfg_num_samples = 10  # Number of samples to use from the dataset. Picked randomly.
g_cfg_batch_size = 3  # Number of samples in each batch (B)
# Fit
g_cfg_num_embeddings = 2  # L in the paper
g_cfg_learning_rate = 0.01  # beta in the paper
g_cfg_decay = tf.Variable(0.001, dtype=tf.float32)  # delta in the paper
# NOTE(review): g_cfg_max_iterations is not referenced anywhere in this file —
# the training loop in __main__ makes a single pass over the batched dataset.
g_cfg_max_iterations = 5  # max number of iterations for the optimization
g_cfg_convergence_threshold = 1e-6  # Stop when |loss change| drops below this
# NOTE(review): g_cfg_eps is not referenced in this file; presumably a
# leftover from a finite-difference (numerical-gradient) variant.
g_cfg_eps = 1e-6  # For numerical gradient calculation
g_cfg_keep_material_params = False  # If True, start fitting from the current material parameters
# RF
g_cfg_central_frequency = 3.75e9  # Set into scene.frequency
g_cfg_subcarrier_spacing = 100e6  # Delta_f in the paper
g_cfg_sampling_rate = 1. / 1000.  # NOTE(review): unused in this file
# Spatial
# Per-receiver positions (meters). Pairs of entries share a location but
# differ in orientation (see g_cfg_rx_orientations, indexed in lockstep).
g_cfg_rx_locations = [
    [4, 5, 2.2],
    [4, 5, 2.2],
    [20.5, 8, 3],
    [20.5, 8, 3],
    [22, 0.5, 2.85],
    [22, 0.5, 2.85]
]
# Per-receiver orientation angles (radians), one triple per receiver above.
g_cfg_rx_orientations = [
    [0, 0, -PI/2],
    [0, 0, 0],
    [0, 0, -PI/2],
    [0, 0, 0],
    [0, 0, PI/2],
    [0, 0, 0]
]
g_cfg_tx_position_initial = np.array([5.5, 5, 0.5])  # Initial TX position
g_cfg_tx_velocity = np.array([0.0005, 0, 0])  # meters per millisecond
# Sionna
g_cfg_max_depth = 3  # Max number of interactions per traced path
g_cfg_trace_paths_num_samples = 1000  # Number of samples to trace paths for
g_cfg_synthetic_array = False  # NOTE(review): unused in this file
g_cfg_los = True  # Include line-of-sight paths
g_cfg_specular_reflection = True  # Passed as `reflection=` to sionna calls
g_cfg_diffuse_reflection = True  # Passed as `scattering=` to sionna calls
# NOTE(review): g_cfg_refraction is passed as `diffraction=` to both
# trace_paths and cir — confirm that this naming/mapping is intentional.
g_cfg_refraction = True
# Output
g_cfg_record_loss = True  # Save losses.npy after training
g_cfg_record_omegas = True  # Save omegas.npy after training

# Globals
g_scene: srt.Scene  # Loaded sionna scene (assigned in __main__)
g_trainable_materials: List[srt.RadioMaterial] = []  # List of trainable materials in the scene
g_bandwidth: float  # Bandwidth of the signal (W in the paper)
     77 
     78 
     79 @tf.function
     80 def calc_material_parameters(omega: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
     81     """Calculate the material parameters from the read-out vectors and embeddings.
     82 
     83     Args:
     84         omega: Read-out and embeddings tensor. Shape (2, 4, num_materials, num_embeddings)
     85 
     86     Returns:
     87         Tuple containing:
     88             - cond: Conductivity of the materials.
     89             - perm: Relative permittivity of the materials.
     90             - scat: Scattering coefficient of the materials.
     91             - xpd: XPD coefficient of the materials.
     92     """
     93     dot = tf.reduce_sum(omega[0] * omega[1], axis=-1)
     94     cond = tf.exp(dot[0])
     95     perm = tf.exp(dot[1]) + 1.
     96     scat = tf.sigmoid(dot[2])
     97     xpd = tf.sigmoid(dot[3])
     98     return cond, perm, scat, xpd
     99 
    100 
    101 @tf.function
    102 def calc_omega(
    103     cond: tf.Tensor,
    104     perm: tf.Tensor,
    105     scat: tf.Tensor,
    106     xpd: tf.Tensor,
    107 ) -> tf.Tensor:
    108     """Calculate the omega vector from the material parameters.
    109 
    110     Args:
    111         cond: Conductivity of the materials.
    112         perm: Relative permittivity of the materials.
    113         scat: Scattering coefficient of the materials.
    114         xpd: XPD coefficient of the materials.
    115 
    116     Returns:
    117         The omega vector (shape: (2, 4, num_materials, num_embeddings,)).
    118     """
    119     num_materials = tf.shape(cond)[0]
    120     dot0 = tf.math.log(cond)
    121     dot1 = tf.math.log(perm - 1.)
    122     dot2 = -tf.math.log(1 / scat - 1)
    123     dot3 = -tf.math.log(1 / xpd - 1)
    124     dot = tf.stack([dot0, dot1, dot2, dot3], axis=0)
    125     w = tf.random.uniform((4, num_materials, g_cfg_num_embeddings), minval=-1, maxval=1, dtype=tf.float32)
    126     v = dot[:, :, tf.newaxis] / w
    127     return tf.stack([v, w], axis=0)
    128 
    129 
    130 @tf.function
    131 def set_omega(omega: tf.Tensor) -> None:
    132     """Set the trainable parameters in the scene from omega."""
    133     cond, perm, scat, xpd = calc_material_parameters(omega)
    134     for idx, mat in enumerate(g_trainable_materials):
    135         g_scene._radio_materials[mat.name].conductivity = tf.gather(cond, idx)
    136         g_scene._radio_materials[mat.name].relative_permittivity = tf.gather(perm, idx)
    137         g_scene._radio_materials[mat.name].scattering_coefficient = tf.gather(scat, idx)
    138         g_scene._radio_materials[mat.name].xpd_coefficient = tf.gather(xpd, idx)
    139 
    140 
g_alpha = tf.Variable(6e-9, dtype=tf.float32)  # Initial α_i (power-scale EMA state)
# NOTE(review): g_alpha is never .assign()ed anywhere in this file, so the
# EMA below always mixes alpha_hat with the constant initial value 6e-9 —
# confirm against Algorithm 1 of the paper, which appears to update α_i.
@tf.function
def objective(cirs_batch: tf.Tensor, batch_indices: tf.Tensor, omega: tf.Tensor, paths_buffers: List) -> tf.Tensor:
    """Compute the objective function for the optimization.
    Returns the average loss over the batch as in (38) of the paper.

    Applies omega to the scene materials, recomputes the EM fields for each
    pre-traced paths buffer in the batch, and compares the simulated channel
    gain / delay statistics against the measured CIRs.

    Args:
        cirs_batch: Measured CIRs for the batch; the reduce over
            axis=(0, 1, 2) below implies the last axis enumerates the
            batch samples — layout comes from the dataset in __main__.
        batch_indices: Indices into paths_buffers for this batch.
        omega: Trainable read-out/embedding tensor,
            shape (2, 4, num_materials, num_embeddings).
        paths_buffers: List of traced-paths tuples from g_scene.trace_paths.

    Returns:
        Scalar batch-average loss.
    """
    set_omega(omega)
    p_hat = tf.TensorArray(dtype=tf.float32, size=tf.shape(batch_indices)[0])
    tau_hat = tf.TensorArray(dtype=tf.float32, size=tf.shape(batch_indices)[0])
    for i in tf.range(tf.shape(batch_indices)[0]):
        ind = batch_indices[i]
        # Wrapped in tf.py_function so the Python-list indexing of
        # paths_buffers works inside the tf.function trace.
        def get_a_tau(index):
            return g_scene.compute_fields(
                *paths_buffers[int(index)],
                check_scene=False,
                scat_random_phases=False,
                testing=False,
            ).cir(
                los=g_cfg_los,
                reflection=g_cfg_specular_reflection,
                diffraction=g_cfg_refraction,
                scattering=g_cfg_diffuse_reflection,
                ris=False,
                cluster_ris_paths=False,
                num_paths=None,
            )
        a, tau = tf.py_function(
            get_a_tau,
            [ind],
            Tout=[tf.complex64, tf.float32]
        )
        # Record simulated total path power and mean delay for this sample.
        p_val = tf.reduce_sum(tf.abs(a) ** 2)
        tau_val = tf.reduce_mean(tau)
        p_hat = p_hat.write(i, p_val)
        tau_hat = tau_hat.write(i, tau_val)
    p_hat = p_hat.stack()
    tau_hat = tau_hat.stack()

    # Scale the batch: h_b ← √α_i * h_b
    # alpha_hat is the least-squares scale matching measured power p to
    # simulated power p_hat, then blended with g_alpha via g_cfg_decay.
    p = tf.reduce_sum(tf.abs(cirs_batch) ** 2, axis=(0, 1, 2))
    alpha_hat = tf.reduce_sum(p[..., tf.newaxis] * p_hat) / (tf.reduce_sum(p_hat ** 2) + 1e-8)
    alpha = g_cfg_decay * g_alpha + (1 - g_cfg_decay) * alpha_hat
    alpha = tf.cast(alpha, tf.complex64)
    cirs_scaled = tf.sqrt(alpha) * cirs_batch
    p = tf.reduce_sum(tf.abs(cirs_scaled) ** 2, axis=(0, 1, 2))

    # Approximate tau_RMS as in (40, 41)
    total_channel_gain = tf.reduce_sum(p)
    p = p / (total_channel_gain + 1e-8)
    # NOTE(review): 512 is hard-coded — presumably the number of subcarriers
    # in the measurements; verify against the cirs shape. Also, ls is an
    # int64 numpy array mixed into float32 TF ops below — confirm TF does
    # not raise a dtype-mismatch here.
    ls = np.arange(512) - 512 // 2 # Subcarrier indices
    tau_overline = tf.reduce_sum(ls[:, None] * p, axis=0)
    tau_rms = (ls - tau_overline) / g_bandwidth
    tau_rms = tf.sqrt(tf.reduce_sum((tau_rms ** 2)[:, None] * p))

    # Compute sample loss L_b as in (38)
    # Symmetric mean absolute percentage error, bounded in [0, 1).
    def smape(x, y):
        return tf.abs(x - y) / (tf.abs(x) + tf.abs(y) + 1e-8)
    losses = tf.reduce_mean(smape(p[..., tf.newaxis], p_hat), axis=0) + smape(tau_rms, tau_hat)
    # Average loss over the batch: L = (1/B) * sum(L_b)
    return tf.reduce_mean(losses)
    202 
    203 
if __name__ == "__main__":
    # Initialization
    # Scene
    print("Loading scene... ", end='', flush=True)
    g_scene = srt.load_scene(
        os.path.join(
            os.path.dirname(__file__),
            "..",
            "scene",
            "scene.xml")
    )
    g_scene.frequency = g_cfg_central_frequency
    # Single-element dipole arrays: TX vertically polarized, RX dual-polarized.
    g_scene.tx_array = srt.PlanarArray(
        num_rows=1,
        num_cols=1,
        vertical_spacing=0.5,
        horizontal_spacing=0.5,
        pattern="dipole",
        polarization="V",
    )
    g_scene.rx_array = srt.PlanarArray(
        num_rows=1,
        num_cols=1,
        vertical_spacing=0.5,
        horizontal_spacing=0.5,
        pattern="dipole",
        polarization="VH",
    )
    # One receiver per configured location/orientation pair.
    for rx_idx, rx_loc in enumerate(g_cfg_rx_locations):
        rx = srt.Receiver(
            name=f"rx{rx_idx}",
            position=rx_loc,
            orientation=g_cfg_rx_orientations[rx_idx],
        )
        g_scene.add(rx)
    # Single transmitter; its position is updated per sample before tracing.
    tx = srt.Transmitter(
        name="tx",
        position=g_cfg_tx_position_initial,
    )
    g_scene.add(tx)
    print("Done")

    # Make trainable materials
    # For every material in use, register a "<name>_train" copy whose
    # parameters set_omega() overwrites during optimization.
    for mat in g_scene.radio_materials.values():
        if not mat.is_used:
            continue
        # Create new trainable material
        if g_cfg_keep_material_params:
            # Start from the current material parameters
            new_mat = srt.RadioMaterial(
                mat.name + "_train",
                relative_permittivity=mat.relative_permittivity,
                conductivity=mat.conductivity,
                scattering_coefficient=mat.scattering_coefficient,
                xpd_coefficient=mat.xpd_coefficient
            )
        else:
            # Generic starting point; scattering/XPD keep library defaults.
            new_mat = srt.RadioMaterial(mat.name + "_train",
                                        relative_permittivity=3.0,
                                        conductivity=0.1)
        g_scene.add(new_mat)
        g_trainable_materials.append(new_mat)

    # Assign trainable materials to the corresponding objects
    # NOTE(review): assigns the material *name* (a str) to obj.radio_material;
    # presumably sionna resolves registered material names on assignment —
    # confirm against the SceneObject API.
    for obj in g_scene.objects.values():
        obj.radio_material = obj.radio_material.name + "_train"

    # Load the measurements
    print("Loading measurements... ", end='', flush=True)
    data_dirpath = os.path.join(
        os.path.dirname(__file__),
        "..", "..", "data"
    )
    # [num_aps, num_rfs, num_seconds, num_subcarriers, 1000]
    cirs = measurements_load(data_dirpath, g_cfg_meas_round)
    # Cast to complex64
    cirs = cirs.astype(np.complex64)
    # Move num_seconds to 0 and 1000 to 1
    # [num_seconds, 1000, num_aps, num_rfs, num_subcarriers]
    cirs = cirs.transpose((2, 4, 0, 1, 3))
    # Unite num_seconds and 1000
    # [num_seconds * 1000, num_aps, num_rfs, num_subcarriers]
    cirs = cirs.reshape((cirs.shape[0] * cirs.shape[1], cirs.shape[2], cirs.shape[3], cirs.shape[4]))
    # NOTE(review): per the layout comment above, shape[-2] is num_rfs and
    # shape[-1] is num_subcarriers — W = num_subcarriers * Delta_f would use
    # shape[-1]; confirm which axis is intended here.
    g_bandwidth = cirs.shape[-2] * g_cfg_subcarrier_spacing  # W in the paper
    print("Done")

    # Make the dataset
    # Contains:
    # - CIRs (shape (num_samples, num_aps, num_rfs, num_subcarriers))
    # - paths indices (shape (num_samples,))
    # NOTE(review): no RNG seed is fixed, so sample selection is not
    # reproducible; randint may also pick duplicate samples.
    indices = np.random.randint(0, cirs.shape[0], (g_cfg_num_samples,))
    dataset = tf.data.Dataset.from_tensor_slices((cirs[indices], np.arange(g_cfg_num_samples)))
    assert dataset.cardinality() == g_cfg_num_samples

    # Trace paths
    print("Tracing paths... ", end="", flush=True)
    paths_buffers = []  # Initialize the paths buffers list
    # TX moves linearly in time: position = initial + sample_index * velocity
    # (indices are in milliseconds, velocity is meters per millisecond).
    tx_poss = g_cfg_tx_position_initial + indices[:, np.newaxis] * g_cfg_tx_velocity
    for tx_pos in tx_poss:
        g_scene._transmitters['tx'].position = tx_pos
        # Geometry-only tracing here; the fields are recomputed from these
        # buffers inside objective() on every iteration so material-parameter
        # gradients can flow.
        paths_buffers.append(g_scene.trace_paths(
            max_depth=g_cfg_max_depth,
            num_samples=g_cfg_trace_paths_num_samples,
            los=g_cfg_los,
            reflection=g_cfg_specular_reflection,
            diffraction=g_cfg_refraction,
            scattering=g_cfg_diffuse_reflection,
            ris=False,
            edge_diffraction=False,
            check_scene=False,
        ))
    print("Done")

    # Fit the scene
    prev_loss = tf.Variable(1e10, dtype=tf.float32)  # Sentinel: first loss delta is huge
    omega = tf.random.uniform(
        (2, 4, len(g_trainable_materials), g_cfg_num_embeddings),
        minval=-1, maxval=1, dtype=tf.float32
    )
    omega = tf.Variable(omega, dtype=tf.float32)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=g_cfg_learning_rate,
        # beta_1=0.9,
        # beta_2=0.999,
        # epsilon=1e-8,
    )

    print("Starting optimization...", flush=True)
    loss_history = []
    omega_history = []
    dataset = dataset.batch(g_cfg_batch_size)
    # NOTE(review): this is a single pass over the batches —
    # g_cfg_max_iterations is never consulted here. The convergence test
    # compares losses of *different* batches, so it can fire spuriously.
    for cirs_batch, paths_indices_batch in dataset:
        with tf.GradientTape() as tape:
            curr_loss = objective(cirs_batch, paths_indices_batch, omega, paths_buffers)
        grad = tape.gradient(curr_loss, omega)
        optimizer.apply_gradients(zip([grad], [omega]))

        if g_cfg_record_loss:
            loss_history.append(curr_loss)
        if g_cfg_record_omegas:
            omega_history.append(omega.numpy().copy())

        # Check convergence: Stop if change in loss is below threshold
        if abs(prev_loss - curr_loss) < g_cfg_convergence_threshold:
            print("Converged")
            break
        prev_loss = curr_loss
        print(f"Loss: {curr_loss}", flush=True)

    print("Training completed")

    # Save the losses and omegas
    losses = np.asarray(loss_history)  # list of scalar tf.Tensors -> 1-D array
    omegas = np.asarray(omega_history)
    if g_cfg_record_loss:
        np.save("losses.npy", losses)
        print("Saved losses to losses.npy: ", losses.shape)
    if g_cfg_record_omegas:
        omegas = omegas.reshape(
            (-1, 2, 4, len(g_trainable_materials), g_cfg_num_embeddings)
        )
        np.save("omegas.npy", omegas)
        print("Saved omegas to omegas.npy: ", omegas.shape)