cir_dataset.py (6145B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class for creating a CIR sampler, usable as a channel model, from a CIR
generator"""


import tensorflow as tf

from . import ChannelModel


class CIRDataset(ChannelModel):
    # pylint: disable=line-too-long
    r"""CIRDataset(cir_generator, batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps, dtype=tf.complex64)

    Creates a channel model from a dataset that can be used with classes such as
    :class:`~sionna.channel.TimeChannel` and :class:`~sionna.channel.OFDMChannel`.
    The dataset is defined by a `generator <https://wiki.python.org/moin/Generators>`_.

    The batch size is configured when instantiating the dataset or through the
    :attr:`~sionna.channel.CIRDataset.batch_size` property.
    The number of time steps (`num_time_steps`) can only be set when
    instantiating the dataset.
    The ``batch_size``, ``num_time_steps`` and ``sampling_frequency`` arguments
    of ``__call__`` are accepted but ignored: every call returns one batch
    built with the configured batch size and number of time steps.
    The specified values must be in accordance with the data.

    Samples produced by the generator are shuffled using a buffer of 32
    elements (reshuffled on every pass), and the dataset is repeated
    indefinitely.

    Example
    -------

    The following code snippet shows how to use this class as a channel model.

    >>> my_generator = MyGenerator(...)
    >>> channel_model = sionna.channel.CIRDataset(my_generator,
    ...                                           batch_size,
    ...                                           num_rx,
    ...                                           num_rx_ant,
    ...                                           num_tx,
    ...                                           num_tx_ant,
    ...                                           num_paths,
    ...                                           num_time_steps+l_tot-1)
    >>> channel = sionna.channel.TimeChannel(channel_model, bandwidth, num_time_steps)

    where ``MyGenerator`` is a generator

    >>> class MyGenerator:
    ...
    ...     def __call__(self):
    ...         ...
    ...         yield a, tau

    that returns complex-valued path coefficients ``a`` with shape
    `[num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps]`
    and real-valued path delays ``tau`` (in seconds)
    `[num_rx, num_tx, num_paths]`.

    Parameters
    ----------
    cir_generator : `generator <https://wiki.python.org/moin/Generators>`_
        Callable taking no arguments that returns an iterator (e.g., a
        generator) yielding channel impulse responses ``(a, tau)``, where
        ``a`` is the tensor of channel coefficients of shape
        `[num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps]`
        and dtype ``dtype``, and ``tau`` the tensor of path delays
        of shape `[num_rx, num_tx, num_paths]` and dtype ``dtype.real_dtype``.
        It is passed as-is to ``tf.data.Dataset.from_generator``.

    batch_size : int
        Batch size

    num_rx : int
        Number of receivers (:math:`N_R`)

    num_rx_ant : int
        Number of antennas per receiver (:math:`N_{RA}`)

    num_tx : int
        Number of transmitters (:math:`N_T`)

    num_tx_ant : int
        Number of antennas per transmitter (:math:`N_{TA}`)

    num_paths : int
        Number of paths (:math:`M`)

    num_time_steps : int
        Number of time steps

    dtype : tf.DType
        Complex datatype to use for internal processing and output.
        Defaults to `tf.complex64`.

    Output
    ------
    a : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time_steps], tf.complex
        Path coefficients

    tau : [batch size, num_rx, num_tx, num_paths], tf.float
        Path delays [s]
    """

    def __init__(self, cir_generator, batch_size, num_rx, num_rx_ant, num_tx,
                 num_tx_ant, num_paths, num_time_steps, dtype=tf.complex64):

        self._cir_generator = cir_generator
        self._batch_size = batch_size
        self._num_time_steps = num_time_steps

        # TensorFlow dataset. The signature fixes the shape and dtype of every
        # yielded sample; data that does not match raises at iteration time.
        output_signature = (tf.TensorSpec(shape=[num_rx,
                                                 num_rx_ant,
                                                 num_tx,
                                                 num_tx_ant,
                                                 num_paths,
                                                 num_time_steps],
                                          dtype=dtype),
                            tf.TensorSpec(shape=[num_rx,
                                                 num_tx,
                                                 num_paths],
                                          dtype=dtype.real_dtype))
        dataset = tf.data.Dataset.from_generator(cir_generator,
                                            output_signature=output_signature)
        # Shuffle with a 32-element buffer, reshuffled on every pass.
        dataset = dataset.shuffle(32, reshuffle_each_iteration=True)
        # Repeat indefinitely so that sampling can go on for as many calls as
        # needed. The unbatched dataset is kept so that it can be re-batched
        # when `batch_size` changes.
        self._dataset = dataset.repeat(None)
        self._batched_dataset = self._dataset.batch(batch_size)
        # Iterator for sampling the dataset
        self._iter = iter(self._batched_dataset)

    @property
    def batch_size(self):
        """Batch size"""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, value):
        """Set the batch size"""
        # Re-batch the unbatched dataset and restart the iterator so that
        # subsequent calls return batches of the new size.
        self._batched_dataset = self._dataset.batch(value)
        self._iter = iter(self._batched_dataset)
        self._batch_size = value

    def __call__(self, batch_size=None,
                 num_time_steps=None,
                 sampling_frequency=None):
        """Return the next batch of channel impulse responses.

        All arguments are accepted but ignored. The batch size is the one
        configured for the dataset (use the ``batch_size`` property to change
        it), and the number of time steps is the one set when instantiating.

        Output
        ------
        (a, tau)
            Next batch drawn from the underlying dataset, shaped as described
            in the class docstring.
        """
        # pylint: disable=unused-argument
        return next(self._iter)