anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

precoding.py (7043B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition and functions related to OFDM transmit precoding"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 import sionna
     10 from sionna.utils import flatten_dims
     11 from sionna.mimo import zero_forcing_precoder
     12 from sionna.ofdm import RemoveNulledSubcarriers
     13 
     14 
     15 class ZFPrecoder(Layer):
     16     # pylint: disable=line-too-long
     17     r"""ZFPrecoder(resource_grid, stream_management, return_effective_channel=False, dtype=tf.complex64, **kwargs)
     18 
     19     Zero-forcing precoding for multi-antenna transmissions.
     20 
     21     This layer precodes a tensor containing OFDM resource grids using
     22     the :meth:`~sionna.mimo.zero_forcing_precoder`. For every
     23     transmitter, the channels to all intended receivers are gathered
     24     into a channel matrix, based on the which the precoding matrix
     25     is computed and the input tensor is precoded. The layer also outputs
     26     optionally the effective channel after precoding for each stream.
     27 
     28     Parameters
     29     ----------
     30     resource_grid : ResourceGrid
     31         An instance of :class:`~sionna.ofdm.ResourceGrid`.
     32 
     33     stream_management : StreamManagement
     34         An instance of :class:`~sionna.mimo.StreamManagement`.
     35 
     36     return_effective_channel : bool
     37         Indicates if the effective channel after precoding should be returned.
     38 
     39     dtype : tf.Dtype
     40         Datatype for internal calculations and the output dtype.
     41         Defaults to `tf.complex64`.
     42 
     43     Input
     44     -----
     45     (x, h) :
     46         Tuple:
     47 
     48     x : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
     49         Tensor containing the resource grid to be precoded.
     50 
     51     h : [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm, fft_size], tf.complex
     52         Tensor containing the channel knowledge based on which the precoding
     53         is computed.
     54 
     55     Output
     56     ------
     57     x_precoded : [batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size], tf.complex
     58         The precoded resource grids.
     59 
     60     h_eff : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm, num_effective_subcarriers], tf.complex
     61         Only returned if ``return_effective_channel=True``.
     62         The effectice channels for all streams after precoding. Can be used to
     63         simulate perfect channel state information (CSI) at the receivers.
     64         Nulled subcarriers are automatically removed to be compliant with the
     65         behavior of a channel estimator.
     66 
     67     Note
     68     ----
     69     If you want to use this layer in Graph mode with XLA, i.e., within
     70     a function that is decorated with ``@tf.function(jit_compile=True)``,
     71     you must set ``sionna.Config.xla_compat=true``.
     72     See :py:attr:`~sionna.Config.xla_compat`.
     73     """
     74     def __init__(self,
     75                  resource_grid,
     76                  stream_management,
     77                  return_effective_channel=False,
     78                  dtype=tf.complex64,
     79                  **kwargs):
     80         super().__init__(dtype=dtype, **kwargs)
     81         assert isinstance(resource_grid, sionna.ofdm.ResourceGrid)
     82         assert isinstance(stream_management, sionna.mimo.StreamManagement)
     83         self._resource_grid = resource_grid
     84         self._stream_management = stream_management
     85         self._return_effective_channel = return_effective_channel
     86         self._remove_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)
     87 
     88     def _compute_effective_channel(self, h, g):
     89         """Compute effective channel after precoding"""
     90 
     91         # Input dimensions:
     92         # h: [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant,...
     93         #     ..., num_ofdm, fft_size]
     94         # g: [batch_size, num_tx, num_ofdm_symbols, fft_size, num_tx_ant,
     95         #     ..., num_streams_per_tx]
     96 
     97         # Transpose h to shape:
     98         # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
     99         #  ..., num_tx_ant]
    100         h = tf.transpose(h, [0, 1, 3, 5, 6, 2, 4])
    101         h = tf.cast(h, g.dtype)
    102 
    103         # Add one dummy dimension to g to be broadcastable to h:
    104         # [batch_size, 1, num_tx, num_ofdm_symbols, fft_size, num_tx_ant,...
    105         #  ..., num_streams_per_tx]
    106         g = tf.expand_dims(g, 1)
    107 
    108         # Compute post precoding channel:
    109         # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
    110         #  ..., num_streams_per_tx]
    111         h_eff = tf.matmul(h, g)
    112 
    113         # Permute dimensions to common format of channel tensors:
    114         # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
    115         #  ..., num_ofdm, fft_size]
    116         h_eff = tf.transpose(h_eff, [0, 1, 5, 2, 6, 3, 4])
    117 
    118         # Remove nulled subcarriers:
    119         # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
    120         #  ..., num_ofdm, num_effective_subcarriers]
    121         h_eff = self._remove_nulled_scs(h_eff)
    122 
    123         return h_eff
    124 
    125     def call(self, inputs):
    126 
    127         x, h = inputs
    128         # x has shape
    129         # [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size]
    130         #
    131         # h has shape
    132         # [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm,...
    133         # ..., fft_size]
    134 
    135         ###
    136         ### Transformations to bring h and x in the desired shapes
    137         ###
    138 
    139         # Transpose x:
    140         #[batch_size, num_tx, num_ofdm_symbols, fft_size, num_streams_per_tx]
    141         x_precoded = tf.transpose(x, [0, 1, 3, 4, 2])
    142         x_precoded = tf.cast(x_precoded, self._dtype)
    143 
    144         # Transpose h:
    145         # [num_tx, num_rx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
    146         #  ..., fft_size, batch_size]
    147         h_pc = tf.transpose(h, [3, 1, 2, 4, 5, 6, 0])
    148 
    149         # Gather desired channel for precoding:
    150         # [num_tx, num_rx_per_tx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
    151         #  ..., fft_size, batch_size]
    152         h_pc_desired = tf.gather(h_pc, self._stream_management.precoding_ind,
    153                                  axis=1, batch_dims=1)
    154 
    155         # Flatten dims 2,3:
    156         # [num_tx, num_rx_per_tx * num_rx_ant, num_tx_ant, num_ofdm_symbols,...
    157         #  ..., fft_size, batch_size]
    158         h_pc_desired = flatten_dims(h_pc_desired, 2, axis=1)
    159 
    160         # Transpose:
    161         # [batch_size, num_tx, num_ofdm_symbols, fft_size,...
    162         #  ..., num_streams_per_tx, num_tx_ant]
    163         h_pc_desired = tf.transpose(h_pc_desired, [5, 0, 3, 4, 1, 2])
    164         h_pc_desired = tf.cast(h_pc_desired, self._dtype)
    165 
    166         ###
    167         ### ZF precoding
    168         ###
    169         x_precoded, g = zero_forcing_precoder(x_precoded,
    170                                               h_pc_desired,
    171                                               return_precoding_matrix=True)
    172 
    173         # Transpose output to desired shape:
    174         #[batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size]
    175         x_precoded = tf.transpose(x_precoded, [0, 1, 4, 2, 3])
    176 
    177         if self._return_effective_channel:
    178             h_eff = self._compute_effective_channel(h, g)
    179             return (x_precoded, h_eff)
    180         else:
    181             return x_precoded