anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

equalization.py (20362B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition and functions related to OFDM channel equalization"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 import sionna
     10 from sionna.utils import flatten_dims, split_dim, flatten_last_dims, expand_to_rank
     11 from sionna.mimo import lmmse_equalizer, zf_equalizer, mf_equalizer
     12 from sionna.ofdm import RemoveNulledSubcarriers
     13 
     14 
class OFDMEqualizer(Layer):
    # pylint: disable=line-too-long
    r"""OFDMEqualizer(equalizer, resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO equalizer for use with the OFDM waveform.

    The parameter ``equalizer`` is a callable (e.g., a function) that
    implements a MIMO equalization algorithm for arbitrary batch dimensions.

    This class pre-processes the received resource grid ``y`` and channel
    estimate ``h_hat``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input
    to the equalization algorithm that is implemented by the callable ``equalizer``.
    This layer computes soft-symbol estimates together with effective noise
    variances for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Note
    -----
    The callable ``equalizer`` must take three inputs:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate two outputs:

    * **x_hat** ([...,num_streams_per_rx], tf.complex) -- 1+D tensor representing the estimated symbol vectors.
    * **no_eff** (tf.float) -- Tensor of the same shape as ``x_hat`` containing the effective noise variance estimates.

    Parameters
    ----------
    equalizer : Callable
        Callable object (e.g., a function) that implements a MIMO equalization
        algorithm for arbitrary batch dimensions

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol
    """
    def __init__(self,
                 equalizer,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Validate the configuration objects before storing them
        assert callable(equalizer)
        assert isinstance(resource_grid, sionna.ofdm.ResourceGrid)
        assert isinstance(stream_management, sionna.mimo.StreamManagement)
        self._equalizer = equalizer
        self._resource_grid = resource_grid
        self._stream_management = stream_management
        # Strips guard bands and the DC subcarrier from the received grid
        self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)

        # Precompute indices to extract data symbols.
        # Sorting the flattened pilot mask in ascending order places the
        # data positions (mask value 0) first, so the first
        # num_data_symbols sorted indices address exactly the
        # data-carrying resource elements.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING")
        self._data_ind = data_ind[...,:num_data_symbols]

    def call(self, inputs):

        y, h_hat, err_var, no = inputs
        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]

        # h_hat has shape:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]

        # err_var has a shape that is broadcastable to h_hat

        # no has shape [batch_size, num_rx, num_rx_ant]
        # or just the first n dimensions of this

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        y_eff = self._removed_nulled_scs(y)

        ####################################################
        ### Prepare the observation y for MIMO detection ###
        ####################################################
        # Transpose y_eff to put num_rx_ant last. New shape:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
        y_dt = tf.cast(y_dt, self._dtype)

        ##############################################
        ### Prepare the err_var for MIMO detection ###
        ##############################################
        # New shape is:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
        err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
        err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
        err_var_dt = flatten_last_dims(err_var_dt, 2)
        err_var_dt = tf.cast(err_var_dt, self._dtype)

        ###############################
        ### Construct MIMO channels ###
        ###############################

        # Reshape h_hat for the construction of desired/interfering channels:
        # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = [1, 3, 4, 0, 2, 5, 6]
        h_dt = tf.transpose(h_hat, perm)

        # Flatten first three dimensions:
        # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt = flatten_dims(h_dt, 3, 0)

        # Gather desired and undesired channels
        ind_desired = self._stream_management.detection_desired_ind
        ind_undesired = self._stream_management.detection_undesired_ind
        h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
        h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)

        # Split first dimension to separate RX and TX:
        # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt_desired = split_dim(h_dt_desired,
                                 [self._stream_management.num_rx,
                                  self._stream_management.num_streams_per_rx],
                                 0)
        h_dt_undesired = split_dim(h_dt_undesired,
                                   [self._stream_management.num_rx, -1], 0)

        # Permute dims to
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
        #  ..., num_rx_ant, num_streams_per_rx(num_interfering_streams_per_rx)]
        perm = [2, 0, 4, 5, 3, 1]
        h_dt_desired = tf.transpose(h_dt_desired, perm)
        h_dt_desired = tf.cast(h_dt_desired, self._dtype)
        h_dt_undesired = tf.transpose(h_dt_undesired, perm)

        ##################################
        ### Prepare the noise variance ###
        ##################################
        # no is first broadcast to [batch_size, num_rx, num_rx_ant]
        # then the rank is expanded to that of y
        # then it is transposed like y to the final shape
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        no_dt = expand_to_rank(no, 3, -1)
        no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
        no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
        no_dt = tf.transpose(no_dt, [0,1,3,4,2])
        no_dt = tf.cast(no_dt, self._dtype)

        ##################################################
        ### Compute the interference covariance matrix ###
        ##################################################
        # Covariance of undesired transmitters
        s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)

        # Thermal noise (diagonal, AWGN assumed uncorrelated across antennas)
        s_no = tf.linalg.diag(no_dt)

        # Channel estimation errors
        # As we have only error variance information for each element,
        # we simply sum them across transmitters and build a
        # diagonal covariance matrix from this
        s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))

        # Final covariance matrix
        s = s_inf + s_no + s_csi
        s = tf.cast(s, self._dtype)

        ############################################################
        ### Compute symbol estimate and effective noise variance ###
        ############################################################
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., num_stream_per_rx]
        x_hat, no_eff = self._equalizer(y_dt, h_dt_desired, s)

        ################################################
        ### Extract data symbols for all detected TX ###
        ################################################
        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, batch_size]
        x_hat = tf.transpose(x_hat, [1, 4, 2, 3, 0])
        no_eff = tf.transpose(no_eff, [1, 4, 2, 3, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
        #  ...,num_effective_subcarriers, batch_size]
        x_hat = flatten_dims(x_hat, 2, 0)
        no_eff = flatten_dims(no_eff, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        x_hat = tf.gather(x_hat, stream_ind, axis=0)
        no_eff = tf.gather(no_eff, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        x_hat = split_dim(x_hat, [num_tx, num_streams], 0)
        no_eff = split_dim(no_eff, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
        #  ..., batch_size]
        x_hat = flatten_dims(x_hat, 2, 2)
        no_eff = flatten_dims(no_eff, 2, 2)

        # Broadcast no_eff to the shape of x_hat
        no_eff = tf.broadcast_to(no_eff, tf.shape(x_hat))

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols, batch_size]
        x_hat = tf.gather(x_hat, self._data_ind, batch_dims=2, axis=2)
        no_eff = tf.gather(no_eff, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams, num_data_symbols]
        x_hat = tf.transpose(x_hat, [3, 0, 1, 2])
        no_eff = tf.transpose(no_eff, [3, 0, 1, 2])

        return (x_hat, no_eff)
    278 
    279 
    280 class LMMSEEqualizer(OFDMEqualizer):
    281     # pylint: disable=line-too-long
    282     """LMMSEEqualizer(resource_grid, stream_management, whiten_interference=True, dtype=tf.complex64, **kwargs)
    283 
    284     LMMSE equalization for OFDM MIMO transmissions.
    285 
    286     This layer computes linear minimum mean squared error (LMMSE) equalization
    287     for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    288     by a :class:`~sionna.ofdm.ResourceGrid` and
    289     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    290     detection algorithm is the :meth:`~sionna.mimo.lmmse_equalizer`. The layer
    291     computes soft-symbol estimates together with effective noise variances
    292     for all streams which can, e.g., be used by a
    293     :class:`~sionna.mapping.Demapper` to obtain LLRs.
    294 
    295     Parameters
    296     ----------
    297     resource_grid : ResourceGrid
    298         Instance of :class:`~sionna.ofdm.ResourceGrid`
    299 
    300     stream_management : StreamManagement
    301         Instance of :class:`~sionna.mimo.StreamManagement`
    302 
    303     whiten_interference : bool
    304         If `True` (default), the interference is first whitened before equalization.
    305         In this case, an alternative expression for the receive filter is used which
    306         can be numerically more stable.
    307 
    308     dtype : tf.Dtype
    309         Datatype for internal calculations and the output dtype.
    310         Defaults to `tf.complex64`.
    311 
    312     Input
    313     -----
    314     (y, h_hat, err_var, no) :
    315         Tuple:
    316 
    317     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    318         Received OFDM resource grid after cyclic prefix removal and FFT
    319 
    320     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    321         Channel estimates for all streams from all transmitters
    322 
    323     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    324         Variance of the channel estimation error
    325 
    326     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    327         Variance of the AWGN
    328 
    329     Output
    330     ------
    331     x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
    332         Estimated symbols
    333 
    334     no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
    335         Effective noise variance for each estimated symbol
    336 
    337     Note
    338     ----
    339     If you want to use this layer in Graph mode with XLA, i.e., within
    340     a function that is decorated with ``@tf.function(jit_compile=True)``,
    341     you must set ``sionna.Config.xla_compat=true``.
    342     See :py:attr:`~sionna.Config.xla_compat`.
    343     """
    344     def __init__(self,
    345                  resource_grid,
    346                  stream_management,
    347                  whiten_interference=True,
    348                  dtype=tf.complex64,
    349                  **kwargs):
    350 
    351         def equalizer(y, h, s):
    352             return lmmse_equalizer(y, h, s, whiten_interference)
    353 
    354         super().__init__(equalizer=equalizer,
    355                          resource_grid=resource_grid,
    356                          stream_management=stream_management,
    357                          dtype=dtype, **kwargs)
    358 
    359 
    360 class ZFEqualizer(OFDMEqualizer):
    361     # pylint: disable=line-too-long
    362     """ZFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)
    363 
    364     ZF equalization for OFDM MIMO transmissions.
    365 
    366     This layer computes zero-forcing (ZF) equalization
    367     for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    368     by a :class:`~sionna.ofdm.ResourceGrid` and
    369     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    370     detection algorithm is the :meth:`~sionna.mimo.zf_equalizer`. The layer
    371     computes soft-symbol estimates together with effective noise variances
    372     for all streams which can, e.g., be used by a
    373     :class:`~sionna.mapping.Demapper` to obtain LLRs.
    374 
    375     Parameters
    376     ----------
    377     resource_grid : ResourceGrid
    378         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    379 
    380     stream_management : StreamManagement
    381         An instance of :class:`~sionna.mimo.StreamManagement`.
    382 
    383     dtype : tf.Dtype
    384         Datatype for internal calculations and the output dtype.
    385         Defaults to `tf.complex64`.
    386 
    387     Input
    388     -----
    389     (y, h_hat, err_var, no) :
    390         Tuple:
    391 
    392     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    393         Received OFDM resource grid after cyclic prefix removal and FFT
    394 
    395     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    396         Channel estimates for all streams from all transmitters
    397 
    398     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    399         Variance of the channel estimation error
    400 
    401     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    402         Variance of the AWGN
    403 
    404     Output
    405     ------
    406     x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
    407         Estimated symbols
    408 
    409     no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
    410         Effective noise variance for each estimated symbol
    411 
    412     Note
    413     ----
    414     If you want to use this layer in Graph mode with XLA, i.e., within
    415     a function that is decorated with ``@tf.function(jit_compile=True)``,
    416     you must set ``sionna.Config.xla_compat=true``.
    417     See :py:attr:`~sionna.Config.xla_compat`.
    418     """
    419     def __init__(self,
    420                  resource_grid,
    421                  stream_management,
    422                  dtype=tf.complex64,
    423                  **kwargs):
    424         super().__init__(equalizer=zf_equalizer,
    425                          resource_grid=resource_grid,
    426                          stream_management=stream_management,
    427                          dtype=dtype, **kwargs)
    428 
    429 
    430 class MFEqualizer(OFDMEqualizer):
    431     # pylint: disable=line-too-long
    432     """MFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)
    433 
    434     MF equalization for OFDM MIMO transmissions.
    435 
    436     This layer computes matched filter (MF) equalization
    437     for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    438     by a :class:`~sionna.ofdm.ResourceGrid` and
    439     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    440     detection algorithm is the :meth:`~sionna.mimo.mf_equalizer`. The layer
    441     computes soft-symbol estimates together with effective noise variances
    442     for all streams which can, e.g., be used by a
    443     :class:`~sionna.mapping.Demapper` to obtain LLRs.
    444 
    445     Parameters
    446     ----------
    447     resource_grid : ResourceGrid
    448         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    449 
    450     stream_management : StreamManagement
    451         An instance of :class:`~sionna.mimo.StreamManagement`.
    452 
    453     dtype : tf.Dtype
    454         Datatype for internal calculations and the output dtype.
    455         Defaults to `tf.complex64`.
    456 
    457     Input
    458     -----
    459     (y, h_hat, err_var, no) :
    460         Tuple:
    461 
    462     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    463         Received OFDM resource grid after cyclic prefix removal and FFT
    464 
    465     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    466         Channel estimates for all streams from all transmitters
    467 
    468     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    469         Variance of the channel estimation error
    470 
    471     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    472         Variance of the AWGN
    473 
    474     Output
    475     ------
    476     x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
    477         Estimated symbols
    478 
    479     no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
    480         Effective noise variance for each estimated symbol
    481 
    482     Note
    483     ----
    484     If you want to use this layer in Graph mode with XLA, i.e., within
    485     a function that is decorated with ``@tf.function(jit_compile=True)``,
    486     you must set ``sionna.Config.xla_compat=true``.
    487     See :py:attr:`~sionna.Config.xla_compat`.
    488     """
    489     def __init__(self,
    490                  resource_grid,
    491                  stream_management,
    492                  dtype=tf.complex64,
    493                  **kwargs):
    494         super().__init__(equalizer=mf_equalizer,
    495                          resource_grid=resource_grid,
    496                          stream_management=stream_management,
    497                          dtype=dtype, **kwargs)