anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

cir_dataset.py (6145B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class for creating a CIR sampler, usable as a channel model, from a CIR
      6     generator"""
      7 
      8 
      9 import tensorflow as tf
     10 
     11 from . import ChannelModel
     12 
     13 class CIRDataset(ChannelModel):
     14     # pylint: disable=line-too-long
     15     r"""CIRDataset(cir_generator, batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps, dtype=tf.complex64)
     16 
     17     Creates a channel model from a dataset that can be used with classes such as
     18     :class:`~sionna.channel.TimeChannel` and :class:`~sionna.channel.OFDMChannel`.
     19     The dataset is defined by a `generator <https://wiki.python.org/moin/Generators>`_.
     20 
     21     The batch size is configured when instantiating the dataset or through the :attr:`~sionna.channel.CIRDataset.batch_size` property.
     22     The number of time steps (`num_time_steps`) and sampling frequency (`sampling_frequency`) can only be set when instantiating the dataset.
     23     The specified values must be in accordance with the data.
     24 
     25     Example
     26     --------
     27 
     28     The following code snippet shows how to use this class as a channel model.
     29 
     30     >>> my_generator = MyGenerator(...)
     31     >>> channel_model = sionna.channel.CIRDataset(my_generator,
     32     ...                                           batch_size,
     33     ...                                           num_rx,
     34     ...                                           num_rx_ant,
     35     ...                                           num_tx,
     36     ...                                           num_tx_ant,
     37     ...                                           num_paths,
     38     ...                                           num_time_steps+l_tot-1)
     39     >>> channel = sionna.channel.TimeChannel(channel_model, bandwidth, num_time_steps)
     40 
     41     where ``MyGenerator`` is a generator
     42 
     43     >>> class MyGenerator:
     44     ...
     45     ...     def __call__(self):
     46     ...         ...
     47     ...         yield a, tau
     48 
     49     that returns complex-valued path coefficients ``a`` with shape
     50     `[num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps]`
     51     and real-valued path delays ``tau`` (in second)
     52     `[num_rx, num_tx, num_paths]`.
     53 
     54     Parameters
     55     ----------
     56     cir_generator : `generator <https://wiki.python.org/moin/Generators>`_
     57         Generator that returns channel impulse responses ``(a, tau)`` where
     58         ``a`` is the tensor of channel coefficients of shape
     59         `[num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps]`
     60         and dtype ``dtype``, and ``tau`` the tensor of path delays
     61         of shape  `[num_rx, num_tx, num_paths]` and dtype ``dtype.
     62         real_dtype``.
     63 
     64     batch_size : int
     65         Batch size
     66 
     67     num_rx : int
     68         Number of receivers (:math:`N_R`)
     69 
     70     num_rx_ant : int
     71         Number of antennas per receiver (:math:`N_{RA}`)
     72 
     73     num_tx : int
     74         Number of transmitters (:math:`N_T`)
     75 
     76     num_tx_ant : int
     77         Number of antennas per transmitter (:math:`N_{TA}`)
     78 
     79     num_paths : int
     80         Number of paths (:math:`M`)
     81 
     82     num_time_steps : int
     83         Number of time steps
     84 
     85     dtype : tf.DType
     86         Complex datatype to use for internal processing and output.
     87         Defaults to `tf.complex64`.
     88 
     89     Output
     90     -------
     91     a : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps], tf.complex
     92         Path coefficients
     93 
     94     tau : [batch size, num_rx, num_tx, num_paths], tf.float
     95         Path delays [s]
     96     """
     97 
     98     def __init__(self, cir_generator, batch_size, num_rx, num_rx_ant, num_tx,
     99         num_tx_ant, num_paths, num_time_steps, dtype=tf.complex64):
    100 
    101         self._cir_generator = cir_generator
    102         self._batch_size = batch_size
    103         self._num_time_steps = num_time_steps
    104 
    105         # TensorFlow dataset
    106         output_signature = (tf.TensorSpec(shape=[num_rx,
    107                                                  num_rx_ant,
    108                                                  num_tx,
    109                                                  num_tx_ant,
    110                                                  num_paths,
    111                                                  num_time_steps],
    112                                           dtype=dtype),
    113                             tf.TensorSpec(shape=[num_rx,
    114                                                  num_tx,
    115                                                  num_paths],
    116                                           dtype=dtype.real_dtype))
    117         dataset = tf.data.Dataset.from_generator(cir_generator,
    118                                             output_signature=output_signature)
    119         dataset = dataset.shuffle(32, reshuffle_each_iteration=True)
    120         self._dataset = dataset.repeat(None)
    121         self._batched_dataset = self._dataset.batch(batch_size)
    122         # Iterator for sampling the dataset
    123         self._iter = iter(self._batched_dataset)
    124 
    125     @property
    126     def batch_size(self):
    127         """Batch size"""
    128         return self._batch_size
    129 
    130     @batch_size.setter
    131     def batch_size(self, value):
    132         """Set the batch size"""
    133         self._batched_dataset = self._dataset.batch(value)
    134         self._iter = iter(self._batched_dataset)
    135         self._batch_size = value
    136 
    137     def __call__(self, batch_size=None,
    138                        num_time_steps=None,
    139                        sampling_frequency=None):
    140 
    141 #         if ( (batch_size is not None)
    142 #                 and tf.not_equal(batch_size, self._batch_size) ):
    143 #             tf.print("Warning: The value of `batch_size` specified when calling \
    144 # the CIRDataset is different from the one configured for the dataset. \
    145 # The value specified when calling is ignored. Use the `batch_size` property \
    146 # of CIRDataset to use a batch size different from the one set when \
    147 # instantiating.")
    148 
    149 #         if ( (num_time_steps is not None)
    150 #             and tf.not_equal(num_time_steps, self._num_time_steps) ):
    151 #             tf.print("Warning: The value of `num_time_steps` specified when \
    152 # calling the CIRDataset is different from the one speficied when instantiating \
    153 # the dataset. The value specified when calling is ignored.")
    154 
    155         return next(self._iter)