flat_fading_channel.py (8588B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes for the simulation of flat-fading channels"""

import tensorflow as tf
from sionna.channel import AWGN
from sionna.utils import complex_normal


class GenerateFlatFadingChannel:
    # pylint: disable=line-too-long
    r"""GenerateFlatFadingChannel(num_tx_ant, num_rx_ant, spatial_corr=None, dtype=tf.complex64, **kwargs)

    Generates tensors of flat-fading channel realizations.

    This class generates batches of random flat-fading channel matrices.
    A spatial correlation can be applied.

    Parameters
    ----------
    num_tx_ant : int
        Number of transmit antennas.

    num_rx_ant : int
        Number of receive antennas.

    spatial_corr : SpatialCorrelation, None
        An instance of :class:`~sionna.channel.SpatialCorrelation` or `None`.
        Defaults to `None`.

    dtype : tf.complex64, tf.complex128
        The dtype of the output. Defaults to `tf.complex64`.

    Input
    -----
    batch_size : int
        The batch size, i.e., the number of channel matrices to generate.

    Output
    ------
    h : [batch_size, num_rx_ant, num_tx_ant], ``dtype``
        Batch of random flat fading channel matrices.

    """
    def __init__(self, num_tx_ant, num_rx_ant, spatial_corr=None, dtype=tf.complex64, **kwargs):
        # NOTE(review): this class derives directly from ``object``, so any
        # extra keyword argument forwarded here makes ``object.__init__``
        # raise a TypeError. ``**kwargs`` is kept only for interface
        # compatibility.
        super().__init__(**kwargs)
        self._num_tx_ant = num_tx_ant
        self._num_rx_ant = num_rx_ant
        self._dtype = dtype
        self.spatial_corr = spatial_corr

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        self._spatial_corr = value

    def __call__(self, batch_size):
        """Draw ``batch_size`` channel matrices.

        The matrices of shape ``[batch_size, num_rx_ant, num_tx_ant]`` are
        drawn with :func:`sionna.utils.complex_normal`, i.e. standard
        complex Gaussian entries. If :attr:`spatial_corr` is not `None`,
        it is applied to the result.
        """
        # Generate standard complex Gaussian matrices
        shape = [batch_size, self._num_rx_ant, self._num_tx_ant]
        h = complex_normal(shape, dtype=self._dtype)

        # Apply spatial correlation
        if self.spatial_corr is not None:
            h = self.spatial_corr(h)

        return h


class ApplyFlatFadingChannel(tf.keras.layers.Layer):
    # pylint: disable=line-too-long
    r"""ApplyFlatFadingChannel(add_awgn=True, dtype=tf.complex64, **kwargs)

    Applies given channel matrices to a vector input and adds AWGN.

    This class applies a given tensor of flat-fading channel matrices
    to an input tensor. AWGN can optionally be added.
    Mathematically, for channel matrices
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`
    and input :math:`\mathbf{x}\in\mathbb{C}^{K}`, the output is

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{n}\in\mathbb{C}^{M}\sim\mathcal{CN}(0, N_o\mathbf{I})`
    is an AWGN vector that is optionally added.


    Parameters
    ----------
    add_awgn: bool
        Indicates if AWGN should be added to the output.
        Defaults to `True`.

    dtype : tf.complex64, tf.complex128
        The dtype of the output. Defaults to `tf.complex64`.

    Input
    -----
    (x, h, no) :
        Tuple:

    x : [batch_size, num_tx_ant], tf.complex
        Tensor of transmit vectors.

    h : [batch_size, num_rx_ant, num_tx_ant], tf.complex
        Tensor of channel realizations. Will be broadcast to the
        dimensions of ``x`` if needed.

    no : Scalar or Tensor, tf.float
        The noise power ``no`` is per complex dimension.
        Only required if ``add_awgn==True``.
        Will be broadcast to the shape of ``y``.
        For more details, see :class:`~sionna.channel.AWGN`.

    Output
    ------
    y : [batch_size, num_rx_ant], ``dtype``
        Channel output.
    """
    def __init__(self, add_awgn=True, dtype=tf.complex64, **kwargs):
        super().__init__(trainable=False, dtype=dtype, **kwargs)
        self._add_awgn = add_awgn

    def build(self, input_shape): #pylint: disable=unused-argument
        # The AWGN sub-layer is only created when noise is to be added.
        if self._add_awgn:
            self._awgn = AWGN(dtype=self.dtype)

    def call(self, inputs):
        """Compute ``y = H x`` and optionally add AWGN.

        ``inputs`` is ``(x, h, no)`` when ``add_awgn`` is `True` and
        ``(x, h)`` otherwise.
        """
        # The noise power is only part of the inputs when AWGN is added
        if self._add_awgn:
            x, h, no = inputs
        else:
            x, h = inputs

        # Treat x as a column vector so the (batched) matmul computes H x,
        # then drop the trailing singleton dimension again.
        x = tf.expand_dims(x, axis=-1)
        y = tf.matmul(h, x)
        y = tf.squeeze(y, axis=-1)

        if self._add_awgn:
            y = self._awgn((y, no))

        return y


class FlatFadingChannel(tf.keras.layers.Layer):
    # pylint: disable=line-too-long
    r"""FlatFadingChannel(num_tx_ant, num_rx_ant, spatial_corr=None, add_awgn=True, return_channel=False, dtype=tf.complex64, **kwargs)

    Applies random channel matrices to a vector input and adds AWGN.

    This class combines :class:`~sionna.channel.GenerateFlatFadingChannel` and
    :class:`~sionna.channel.ApplyFlatFadingChannel` and computes the output of
    a flat-fading channel with AWGN.

    For a given batch of input vectors :math:`\mathbf{x}\in\mathbb{C}^{K}`,
    the output is

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` are randomly generated
    flat-fading channel matrices and
    :math:`\mathbf{n}\in\mathbb{C}^{M}\sim\mathcal{CN}(0, N_o\mathbf{I})`
    is an AWGN vector that is optionally added.

    A :class:`~sionna.channel.SpatialCorrelation` can be configured and the
    channel realizations optionally returned. This is useful to simulate
    receiver algorithms with perfect channel knowledge.

    Parameters
    ----------
    num_tx_ant : int
        Number of transmit antennas.

    num_rx_ant : int
        Number of receive antennas.

    spatial_corr : SpatialCorrelation, None
        An instance of :class:`~sionna.channel.SpatialCorrelation` or `None`.
        Defaults to `None`.

    add_awgn: bool
        Indicates if AWGN should be added to the output.
        Defaults to `True`.

    return_channel: bool
        Indicates if the channel realizations should be returned.
        Defaults to `False`.

    dtype : tf.complex64, tf.complex128
        The dtype of the output. Defaults to `tf.complex64`.

    Input
    -----
    (x, no) :
        Tuple or Tensor:

    x : [batch_size, num_tx_ant], tf.complex
        Tensor of transmit vectors.

    no : Scalar or Tensor, tf.float
        The noise power ``no`` is per complex dimension.
        Only required if ``add_awgn==True``.
        Will be broadcast to the dimensions of the channel output if needed.
        For more details, see :class:`~sionna.channel.AWGN`.

    Output
    ------
    (y, h) :
        Tuple or Tensor:

    y : [batch_size, num_rx_ant], ``dtype``
        Channel output.

    h : [batch_size, num_rx_ant, num_tx_ant], ``dtype``
        Channel realizations. Will only be returned if
        ``return_channel==True``.
    """
    def __init__(self,
                 num_tx_ant,
                 num_rx_ant,
                 spatial_corr=None,
                 add_awgn=True,
                 return_channel=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(trainable=False, dtype=dtype, **kwargs)
        self._num_tx_ant = num_tx_ant
        self._num_rx_ant = num_rx_ant
        self._add_awgn = add_awgn
        self._return_channel = return_channel
        self._gen_chn = GenerateFlatFadingChannel(self._num_tx_ant,
                                                  self._num_rx_ant,
                                                  spatial_corr,
                                                  dtype=dtype)
        self._app_chn = ApplyFlatFadingChannel(add_awgn=add_awgn, dtype=dtype)

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._gen_chn.spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        # Forwarded to the internal generator, which owns the correlation
        self._gen_chn.spatial_corr = value

    @property
    def generate(self):
        """The internal :class:`GenerateFlatFadingChannel` instance (callable)."""
        return self._gen_chn

    @property
    def apply(self):
        """The internal :class:`ApplyFlatFadingChannel` instance (callable)."""
        return self._app_chn

    def call(self, inputs):
        """Draw channel realizations for the batch and apply them to ``x``.

        ``inputs`` is ``(x, no)`` when ``add_awgn`` is `True` and just ``x``
        otherwise. Returns ``y``, or ``(y, h)`` if ``return_channel`` is
        `True`.
        """
        if self._add_awgn:
            x, no = inputs
        else:
            x = inputs

        # Generate a batch of channel realizations (one per leading index of x)
        batch_size = tf.shape(x)[0]
        h = self._gen_chn(batch_size)

        # Apply the channel to the input
        if self._add_awgn:
            y = self._app_chn([x, h, no])
        else:
            y = self._app_chn([x, h])

        if self._return_channel:
            return y, h
        return y