anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

flat_fading_channel.py (8588B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Classes for the simulation of flat-fading channels"""
      6 
      7 import tensorflow as tf
      8 from sionna.channel import AWGN
      9 from sionna.utils import complex_normal
     10 
     11 class GenerateFlatFadingChannel():
     12     # pylint: disable=line-too-long
     13     r"""Generates tensors of flat-fading channel realizations.
     14 
     15     This class generates batches of random flat-fading channel matrices.
     16     A spatial correlation can be applied.
     17 
     18     Parameters
     19     ----------
     20     num_tx_ant : int
     21         Number of transmit antennas.
     22 
     23     num_rx_ant : int
     24         Number of receive antennas.
     25 
     26     spatial_corr : SpatialCorrelation, None
     27         An instance of :class:`~sionna.channel.SpatialCorrelation` or `None`.
     28         Defaults to `None`.
     29 
     30     dtype : tf.complex64, tf.complex128
     31         The dtype of the output. Defaults to `tf.complex64`.
     32 
     33     Input
     34     -----
     35     batch_size : int
     36         The batch size, i.e., the number of channel matrices to generate.
     37 
     38     Output
     39     ------
     40     h : [batch_size, num_rx_ant, num_tx_ant], ``dtype``
     41         Batch of random flat fading channel matrices.
     42 
     43     """
     44     def __init__(self, num_tx_ant, num_rx_ant, spatial_corr=None, dtype=tf.complex64, **kwargs):
     45         super().__init__(**kwargs)
     46         self._num_tx_ant = num_tx_ant
     47         self._num_rx_ant = num_rx_ant
     48         self._dtype = dtype
     49         self.spatial_corr = spatial_corr
     50 
     51     @property
     52     def spatial_corr(self):
     53         """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
     54         return self._spatial_corr
     55 
     56     @spatial_corr.setter
     57     def spatial_corr(self, value):
     58         self._spatial_corr = value
     59 
     60     def __call__(self, batch_size):
     61         # Generate standard complex Gaussian matrices
     62         shape = [batch_size, self._num_rx_ant, self._num_tx_ant]
     63         h = complex_normal(shape, dtype=self._dtype)
     64 
     65         # Apply spatial correlation
     66         if self.spatial_corr is not None:
     67             h = self.spatial_corr(h)
     68 
     69         return h
     70 
     71 class ApplyFlatFadingChannel(tf.keras.layers.Layer):
     72     # pylint: disable=line-too-long
     73     r"""ApplyFlatFadingChannel(add_awgn=True, dtype=tf.complex64, **kwargs)
     74 
     75     Applies given channel matrices to a vector input and adds AWGN.
     76 
     77     This class applies a given tensor of flat-fading channel matrices
     78     to an input tensor. AWGN noise can be optionally added.
     79     Mathematically, for channel matrices
     80     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`
     81     and input :math:`\mathbf{x}\in\mathbb{C}^{K}`, the output is
     82 
     83     .. math::
     84 
     85         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
     86 
     87     where :math:`\mathbf{n}\in\mathbb{C}^{M}\sim\mathcal{CN}(0, N_o\mathbf{I})`
     88     is an AWGN vector that is optionally added.
     89 
     90 
     91     Parameters
     92     ----------
     93     add_awgn: bool
     94         Indicates if AWGN noise should be added to the output.
     95         Defaults to `True`.
     96 
     97     dtype : tf.complex64, tf.complex128
     98         The dtype of the output. Defaults to `tf.complex64`.
     99 
    100     Input
    101     -----
    102     (x, h, no) :
    103         Tuple:
    104 
    105     x : [batch_size, num_tx_ant], tf.complex
    106         Tensor of transmit vectors.
    107 
    108     h : [batch_size, num_rx_ant, num_tx_ant], tf.complex
    109         Tensor of channel realizations. Will be broadcast to the
    110         dimensions of ``x`` if needed.
    111 
    112     no : Scalar or Tensor, tf.float
    113         The noise power ``no`` is per complex dimension.
    114         Only required if ``add_awgn==True``.
    115         Will be broadcast to the shape of ``y``.
    116         For more details, see :class:`~sionna.channel.AWGN`.
    117 
    118     Output
    119     ------
    120     y : [batch_size, num_rx_ant], ``dtype``
    121         Channel output.
    122     """
    123     def __init__(self, add_awgn=True, dtype=tf.complex64, **kwargs):
    124         super().__init__(trainable=False, dtype=dtype, **kwargs)
    125         self._add_awgn = add_awgn
    126 
    127     def build(self, input_shape): #pylint: disable=unused-argument
    128         if self._add_awgn:
    129             self._awgn = AWGN(dtype=self.dtype)
    130 
    131     def call(self, inputs):
    132         if self._add_awgn:
    133             x, h, no = inputs
    134         else:
    135             x, h = inputs
    136 
    137         x = tf.expand_dims(x, axis=-1)
    138         y = tf.matmul(h, x)
    139         y = tf.squeeze(y, axis=-1)
    140 
    141         if self._add_awgn:
    142             y = self._awgn((y, no))
    143 
    144         return y
    145 
    146 class FlatFadingChannel(tf.keras.layers.Layer):
    147     # pylint: disable=line-too-long
    148     r"""FlatFadingChannel(num_tx_ant, num_rx_ant, spatial_corr=None, add_awgn=True, return_channel=False, dtype=tf.complex64, **kwargs)
    149 
    150     Applies random channel matrices to a vector input and adds AWGN.
    151 
    152     This class combines :class:`~sionna.channel.GenerateFlatFadingChannel` and
    153     :class:`~sionna.channel.ApplyFlatFadingChannel` and computes the output of
    154     a flat-fading channel with AWGN.
    155 
    156     For a given batch of input vectors :math:`\mathbf{x}\in\mathbb{C}^{K}`,
    157     the output is
    158 
    159     .. math::
    160 
    161         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    162 
    163     where :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` are randomly generated
    164     flat-fading channel matrices and
    165     :math:`\mathbf{n}\in\mathbb{C}^{M}\sim\mathcal{CN}(0, N_o\mathbf{I})`
    166     is an AWGN vector that is optionally added.
    167 
    168     A :class:`~sionna.channel.SpatialCorrelation` can be configured and the
    169     channel realizations optionally returned. This is useful to simulate
    170     receiver algorithms with perfect channel knowledge.
    171 
    172     Parameters
    173     ----------
    174     num_tx_ant : int
    175         Number of transmit antennas.
    176 
    177     num_rx_ant : int
    178         Number of receive antennas.
    179 
    180     spatial_corr : SpatialCorrelation, None
    181         An instance of :class:`~sionna.channel.SpatialCorrelation` or `None`.
    182         Defaults to `None`.
    183 
    184     add_awgn: bool
    185         Indicates if AWGN noise should be added to the output.
    186         Defaults to `True`.
    187 
    188     return_channel: bool
    189         Indicates if the channel realizations should be returned.
    190         Defaults  to `False`.
    191 
    192     dtype : tf.complex64, tf.complex128
    193         The dtype of the output. Defaults to `tf.complex64`.
    194 
    195     Input
    196     -----
    197     (x, no) :
    198         Tuple or Tensor:
    199 
    200     x : [batch_size, num_tx_ant], tf.complex
    201         Tensor of transmit vectors.
    202 
    203     no : Scalar of Tensor, tf.float
    204         The noise power ``no`` is per complex dimension.
    205         Only required if ``add_awgn==True``.
    206         Will be broadcast to the dimensions of the channel output if needed.
    207         For more details, see :class:`~sionna.channel.AWGN`.
    208 
    209     Output
    210     ------
    211     (y, h) :
    212         Tuple or Tensor:
    213 
    214     y : [batch_size, num_rx_ant], ``dtype``
    215         Channel output.
    216 
    217     h : [batch_size, num_rx_ant, num_tx_ant], ``dtype``
    218         Channel realizations. Will only be returned if
    219         ``return_channel==True``.
    220     """
    221     def __init__(self,
    222                  num_tx_ant,
    223                  num_rx_ant,
    224                  spatial_corr=None,
    225                  add_awgn=True,
    226                  return_channel=False,
    227                  dtype=tf.complex64,
    228                  **kwargs):
    229         super().__init__(trainable=False, dtype=dtype, **kwargs)
    230         self._num_tx_ant = num_tx_ant
    231         self._num_rx_ant = num_rx_ant
    232         self._add_awgn = add_awgn
    233         self._return_channel = return_channel
    234         self._gen_chn = GenerateFlatFadingChannel(self._num_tx_ant,
    235                                                   self._num_rx_ant,
    236                                                   spatial_corr,
    237                                                   dtype=dtype)
    238         self._app_chn = ApplyFlatFadingChannel(add_awgn=add_awgn, dtype=dtype)
    239 
    240     @property
    241     def spatial_corr(self):
    242         """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
    243         return self._gen_chn.spatial_corr
    244 
    245     @spatial_corr.setter
    246     def spatial_corr(self, value):
    247         self._gen_chn.spatial_corr = value
    248 
    249     @property
    250     def generate(self):
    251         """Calls the internal :class:`GenerateFlatFadingChannel`."""
    252         return self._gen_chn
    253 
    254     @property
    255     def apply(self):
    256         """Calls the internal :class:`ApplyFlatFadingChannel`."""
    257         return self._app_chn
    258 
    259     def call(self, inputs):
    260         if self._add_awgn:
    261             x, no = inputs
    262         else:
    263             x = inputs
    264 
    265         # Generate a batch of channel realizations
    266         batch_size = tf.shape(x)[0]
    267         h = self._gen_chn(batch_size)
    268 
    269         # Apply the channel to the input
    270         if self._add_awgn:
    271             y = self._app_chn([x, h, no])
    272         else:
    273             y = self._app_chn([x, h])
    274 
    275         if self._return_channel:
    276             return y, h
    277         else:
    278             return y