anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration

pusch_precoder.py (3369B)


#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

"""PUSCH Precoding Layer for the nr (5G) sub-package of the Sionna library."""

import tensorflow as tf
from tensorflow.keras.layers import Layer

class PUSCHPrecoder(Layer):
    # pylint: disable=line-too-long
    r"""
    PUSCHPrecoder(precoding_matrices, dtype=tf.complex64, **kwargs)

    Precodes a batch of modulated symbols mapped onto a resource grid
    for PUSCH transmissions. Each transmitter is assumed to have its
    own precoding matrix.

    Parameters
    ----------
    precoding_matrices : list, [num_tx, num_antenna_ports, num_layers], tf.complex
        List of precoding matrices, one for each transmitter.
        All precoding matrices must have the same shape.

    dtype : One of [tf.complex64, tf.complex128]
        Dtype of inputs and outputs. Defaults to tf.complex64.

    Input
    -----
        : [batch_size, num_tx, num_layers, num_symbols_per_slot, num_subcarriers]
            Batch of resource grids to be precoded

    Output
    ------
        : [batch_size, num_tx, num_antenna_ports, num_symbols_per_slot, num_subcarriers]
            Batch of precoded resource grids
    """
    def __init__(self,
                 precoding_matrices,
                 dtype=tf.complex64,
                 **kwargs):

        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        super().__init__(dtype=dtype, **kwargs)

        self._num_tx = len(precoding_matrices)

        # Check that all precoding matrices have the same shape
        shape = precoding_matrices[0].shape
        w_list = []
        for w in precoding_matrices:
            assert w.shape[0]==shape[0] and w.shape[1]==shape[1], \
                "All precoding matrices must have the same shape"
            w_list.append(w)

        # w has shape:
        # [num_tx, num_antenna_ports, num_layers]
        self._w = tf.constant(w_list, self.dtype)

    def build(self, input_shape):
        _, num_tx, num_layers, _, _ = input_shape
        assert num_tx==len(self._w), \
            f"""The input shape is for {num_tx} transmitters, but you have
                configured precoding matrices for {len(self._w)}."""
        assert num_layers==self._w[0].shape[1], \
            f"""You have configured precoding matrices for
                {self._w[0].shape[1]} layers, but the input
                provides {num_layers} layers."""

    def call(self, inputs):

        # inputs has shape:
        # [batch_size, num_tx, num_layers, num_symbols_per_slot,...
        #  ..., num_subcarriers]

        # Change ordering of dimensions:
        # [batch_size, num_symbols_per_slot, num_subcarriers, num_tx,...
        #  ..., num_layers]
        inputs = tf.transpose(inputs, [0, 3, 4, 1, 2])

        # Add dimension for matrix multiplication:
        inputs = tf.expand_dims(inputs, -1)

        # Precode:
        # [batch_size, num_symbols_per_slot, num_subcarriers,...
        #  ..., num_tx, num_antenna_ports]
        z = tf.squeeze(tf.matmul(self._w, inputs), -1)

        # Re-order:
        # [batch_size, num_tx, num_antenna_ports, num_symbols_per_slot,...
        #  ..., num_subcarriers]
        z = tf.transpose(z, [0, 3, 4, 1, 2])

        return z
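
For reference, a minimal usage sketch of the layer defined above (not part of the repository file). It assumes PUSCHPrecoder can be imported from this module; the 2-port, single-layer precoding matrix is an arbitrary illustrative choice, not a specific 3GPP codebook entry.

import numpy as np
import tensorflow as tf
from pusch_precoder import PUSCHPrecoder  # hypothetical import path for this repository layout

num_tx, num_ports, num_layers = 1, 2, 1
num_symbols, num_subcarriers = 14, 48

# One precoding matrix per transmitter, shape [num_antenna_ports, num_layers]
w = (np.array([[1.0], [1.0]]) / np.sqrt(2.0)).astype(np.complex64)
precoder = PUSCHPrecoder([w])

# Random resource grid:
# [batch_size, num_tx, num_layers, num_symbols_per_slot, num_subcarriers]
x = tf.complex(
    tf.random.normal([4, num_tx, num_layers, num_symbols, num_subcarriers]),
    tf.random.normal([4, num_tx, num_layers, num_symbols, num_subcarriers]))

# Output maps layers to antenna ports:
# [batch_size, num_tx, num_antenna_ports, num_symbols_per_slot, num_subcarriers]
y = precoder(x)
print(y.shape)  # (4, 1, 2, 14, 48)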