precoding.py (7043B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to OFDM transmit precoding"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.utils import flatten_dims
from sionna.mimo import zero_forcing_precoder
from sionna.ofdm import RemoveNulledSubcarriers


class ZFPrecoder(Layer):
    # pylint: disable=line-too-long
    r"""ZFPrecoder(resource_grid, stream_management, return_effective_channel=False, dtype=tf.complex64, **kwargs)

    Zero-forcing precoding for multi-antenna transmissions.

    This layer precodes a tensor containing OFDM resource grids using
    the :meth:`~sionna.mimo.zero_forcing_precoder`. For every
    transmitter, the channels to all intended receivers are gathered
    into a channel matrix, based on which the precoding matrix
    is computed and the input tensor is precoded. The layer also
    optionally outputs the effective channel after precoding for each stream.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    return_effective_channel : bool
        Indicates if the effective channel after precoding should be returned.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (x, h) :
        Tuple:

    x : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        Tensor containing the resource grid to be precoded.

    h : [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm, fft_size], tf.complex
        Tensor containing the channel knowledge based on which the precoding
        is computed.

    Output
    ------
    x_precoded : [batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size], tf.complex
        The precoded resource grids.

    h_eff : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm, num_effective_subcarriers], tf.complex
        Only returned if ``return_effective_channel=True``.
        The effective channels for all streams after precoding. Can be used to
        simulate perfect channel state information (CSI) at the receivers.
        Nulled subcarriers are automatically removed to be compliant with the
        behavior of a channel estimator.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 return_effective_channel=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Fail fast on misconfiguration so shape/stream errors surface at
        # construction time rather than inside the first call().
        assert isinstance(resource_grid, sionna.ofdm.ResourceGrid)
        assert isinstance(stream_management, sionna.mimo.StreamManagement)
        self._resource_grid = resource_grid
        self._stream_management = stream_management
        # Whether call() returns (x_precoded, h_eff) instead of x_precoded only.
        self._return_effective_channel = return_effective_channel
        # Strips nulled (guard/DC) subcarriers from the effective channel so
        # its last dimension matches what a channel estimator would produce.
        self._remove_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)

    def _compute_effective_channel(self, h, g):
        """Compute effective channel after precoding.

        Multiplies the true channel ``h`` by the precoding matrices ``g``
        for each OFDM symbol and subcarrier, yielding the per-stream
        channel a receiver would observe after precoding.
        """

        # Input dimensions:
        # h: [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant,...
        #     ..., num_ofdm, fft_size]
        # g: [batch_size, num_tx, num_ofdm_symbols, fft_size, num_tx_ant,
        #     ..., num_streams_per_tx]

        # Transpose h to shape:
        # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
        #  ..., num_tx_ant]
        # so that the trailing two axes form the per-RE channel matrix.
        h = tf.transpose(h, [0, 1, 3, 5, 6, 2, 4])
        h = tf.cast(h, g.dtype)

        # Add one dummy dimension to g to be broadcastable to h:
        # [batch_size, 1, num_tx, num_ofdm_symbols, fft_size, num_tx_ant,...
        #  ..., num_streams_per_tx]
        # (the inserted axis broadcasts over num_rx)
        g = tf.expand_dims(g, 1)

        # Compute post precoding channel (batched matmul over the leading axes):
        # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
        #  ..., num_streams_per_tx]
        h_eff = tf.matmul(h, g)

        # Permute dimensions to common format of channel tensors:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm, fft_size]
        h_eff = tf.transpose(h_eff, [0, 1, 5, 2, 6, 3, 4])

        # Remove nulled subcarriers:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm, num_effective_subcarriers]
        h_eff = self._remove_nulled_scs(h_eff)

        return h_eff

    def call(self, inputs):

        x, h = inputs
        # x has shape
        # [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size]
        #
        # h has shape
        # [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm,...
        #  ..., fft_size]

        ###
        ### Transformations to bring h and x in the desired shapes
        ###

        # Transpose x:
        # [batch_size, num_tx, num_ofdm_symbols, fft_size, num_streams_per_tx]
        # (streams last, as expected by zero_forcing_precoder)
        x_precoded = tf.transpose(x, [0, 1, 3, 4, 2])
        x_precoded = tf.cast(x_precoded, self._dtype)

        # Transpose h:
        # [num_tx, num_rx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        # num_tx leads so that the per-transmitter gather below can use it
        # as the batch dimension.
        h_pc = tf.transpose(h, [3, 1, 2, 4, 5, 6, 0])

        # Gather desired channel for precoding:
        # [num_tx, num_rx_per_tx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        # precoding_ind selects, per transmitter, the receivers it serves.
        h_pc_desired = tf.gather(h_pc, self._stream_management.precoding_ind,
                                 axis=1, batch_dims=1)

        # Flatten dims 2,3:
        # [num_tx, num_rx_per_tx * num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        # Collapses (receiver, antenna) into a single stream axis; assumes
        # num_rx_per_tx * num_rx_ant equals num_streams_per_tx.
        h_pc_desired = flatten_dims(h_pc_desired, 2, axis=1)

        # Transpose:
        # [batch_size, num_tx, num_ofdm_symbols, fft_size,...
        #  ..., num_streams_per_tx, num_tx_ant]
        h_pc_desired = tf.transpose(h_pc_desired, [5, 0, 3, 4, 1, 2])
        h_pc_desired = tf.cast(h_pc_desired, self._dtype)

        ###
        ### ZF precoding
        ###
        # g is the per-RE precoding matrix, needed for the effective channel.
        x_precoded, g = zero_forcing_precoder(x_precoded,
                                              h_pc_desired,
                                              return_precoding_matrix=True)

        # Transpose output to desired shape:
        # [batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size]
        x_precoded = tf.transpose(x_precoded, [0, 1, 4, 2, 3])

        if self._return_effective_channel:
            h_eff = self._compute_effective_channel(h, g)
            return (x_precoded, h_eff)
        else:
            return x_precoded