pusch_precoder.py (3369B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

"""PUSCH Precoding Layer for the nr (5G) sub-package of the Sionna library."""

import tensorflow as tf
from tensorflow.keras.layers import Layer

class PUSCHPrecoder(Layer):
    # pylint: disable=line-too-long
    r"""
    PUSCHPrecoder(precoding_matrices, dtype=tf.complex64, **kwargs)

    Precodes a batch of modulated symbols mapped onto a resource grid
    for PUSCH transmissions. Each transmitter is assumed to have its
    own precoding matrix.

    For every transmitter and every resource element (OFDM symbol,
    subcarrier), the vector of ``num_layers`` layer symbols :math:`\mathbf{x}`
    is mapped onto the ``num_antenna_ports`` antenna ports as
    :math:`\mathbf{z} = \mathbf{W}\mathbf{x}`, where :math:`\mathbf{W}` is
    that transmitter's precoding matrix.

    Parameters
    ----------
    precoding_matrices : list, [num_tx, num_antenna_ports, num_layers]. tf.complex
        List of precoding matrices, one for each transmitter.
        Each entry must be a 2-D matrix of shape
        [num_antenna_ports, num_layers], and all precoding matrices must
        have the same shape. The list must not be empty.

    dtype : One of [tf.complex64, tf.complex128]
        Dtype of inputs and outputs. Defaults to tf.complex64.

    Input
    -----
    : [batch_size, num_tx, num_layers, num_symbols_per_slot, num_subcarriers]
        Batch of resource grids to be precoded

    Output
    ------
    : [batch_size, num_tx, num_antenna_ports, num_symbols_per_slot, num_subcarriers]
        Batch of precoded resource grids
    """
    def __init__(self,
                 precoding_matrices,
                 dtype=tf.complex64,
                 **kwargs):

        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        super().__init__(dtype=dtype, **kwargs)

        self._num_tx = len(precoding_matrices)

        # Guard explicitly; otherwise indexing [0] below fails with an
        # uninformative IndexError.
        assert self._num_tx > 0, \
            "At least one precoding matrix must be provided"

        # Every matrix must be 2-D, [num_antenna_ports, num_layers], and all
        # must share one shape. Comparing the complete shape (not only the
        # first two dimensions) also rejects matrices of a different rank,
        # which would otherwise fail later with a less informative error.
        shape = tuple(precoding_matrices[0].shape)
        assert len(shape) == 2, \
            "Each precoding matrix must have shape [num_antenna_ports, num_layers]"
        for w in precoding_matrices:
            assert tuple(w.shape) == shape, \
                "All precoding matrices must have the same shape"

        # w has shape:
        # [num_tx, num_antenna_ports, num_layers]
        self._w = tf.constant(list(precoding_matrices), self.dtype)

    def build(self, input_shape):
        """Check that the input matches the configured precoding matrices.

        Raises an ``AssertionError`` if the number of transmitters or the
        number of layers in ``input_shape`` differs from what the precoding
        matrices were configured for.
        """
        _, num_tx, num_layers, _, _ = input_shape
        assert num_tx==len(self._w), \
            (f"The input shape is for {num_tx} transmitters, but you have "
             f"configured precoding matrices for {len(self._w)}.")
        # Dimension 1 of each matrix is num_layers.
        assert num_layers==self._w[0].shape[1], \
            (f"You have configured precoding matrices for "
             f"{self._w[0].shape[1]} layers, but the input "
             f"provides {num_layers} layers.")

    def call(self, inputs):
        """Apply each transmitter's precoding matrix to its layer symbols."""

        # inputs has shape:
        # [batch_size, num_tx, num_layers, num_symbols_per_slot,
        #  num_subcarriers]

        # Move num_tx and num_layers to the end so that they act as the
        # (batch, vector) dimensions of the matrix product:
        # [batch_size, num_symbols_per_slot, num_subcarriers, num_tx,
        #  num_layers]
        inputs = tf.transpose(inputs, [0, 3, 4, 1, 2])

        # Add a trailing dimension so each layer vector is a column vector:
        # [..., num_tx, num_layers, 1]
        inputs = tf.expand_dims(inputs, -1)

        # Precode. self._w ([num_tx, num_antenna_ports, num_layers]) is
        # broadcast over the leading batch/symbol/subcarrier dimensions:
        # [batch_size, num_symbols_per_slot, num_subcarriers, num_tx,
        #  num_antenna_ports]
        z = tf.squeeze(tf.matmul(self._w, inputs), -1)

        # Re-order:
        # [batch_size, num_tx, num_antenna_ports, num_symbols_per_slot,
        #  num_subcarriers]
        z = tf.transpose(z, [0, 3, 4, 1, 2])

        return z