precoding.py (10200B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes and functions related to MIMO transmit precoding"""

import math

import tensorflow as tf

from sionna import PI
from sionna.utils import matrix_inv


def zero_forcing_precoder(x, h, return_precoding_matrix=False):
    # pylint: disable=line-too-long
    r"""Zero-Forcing (ZF) Precoder

    This function implements ZF precoding for a MIMO link, assuming the
    following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{G}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^K` is the received signal vector,
    :math:`\mathbf{H}\in\mathbb{C}^{K\times M}` is the known channel matrix,
    :math:`\mathbf{G}\in\mathbb{C}^{M\times K}` is the precoding matrix,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the symbol vector to be precoded,
    and :math:`\mathbf{n}\in\mathbb{C}^K` is a noise vector. It is assumed that
    :math:`K\le M`.

    The precoding matrix :math:`\mathbf{G}` is defined as (Eq. 4.37) [BHS2017]_ :

    .. math::

        \mathbf{G} = \mathbf{V}\mathbf{D}

    where

    .. math::

        \mathbf{V} &= \mathbf{H}^{\mathsf{H}}\left(\mathbf{H} \mathbf{H}^{\mathsf{H}}\right)^{-1}\\
        \mathbf{D} &= \mathop{\text{diag}}\left( \lVert \mathbf{v}_{k} \rVert_2^{-1}, k=0,\dots,K-1 \right)

    and :math:`\mathbf{v}_{k}` denotes the :math:`k`-th column of
    :math:`\mathbf{V}`. This ensures that each stream is precoded with a
    unit-norm vector, i.e.,
    :math:`\mathop{\text{tr}}\left(\mathbf{G}\mathbf{G}^{\mathsf{H}}\right)=K`.
    The function returns the precoded vector :math:`\mathbf{G}\mathbf{x}`.

    Input
    -----
    x : [...,K], tf.complex
        1+D tensor containing the symbol vectors to be precoded.

    h : [...,K,M], tf.complex
        2+D tensor containing the channel matrices

    return_precoding_matrix : bool
        Indicates if the precoding matrices should be returned or not.
        Defaults to False.

    Output
    -------
    x_precoded : [...,M], tf.complex
        Tensor of the same shape and dtype as ``x`` apart from the last
        dimension, which has changed from `K` to `M`. It contains the
        precoded symbol vectors.

    g : [...,M,K], tf.complex
        2+D tensor containing the precoding matrices. It is only returned
        if ``return_precoding_matrix=True``.

    Note
    ----
    If you want to use this function in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    # Right pseudo-inverse H^H (H H^H)^{-1}: [..., M, K]
    g = tf.matmul(h, h, adjoint_b=True)
    g = tf.matmul(h, matrix_inv(g), adjoint_a=True)

    # Normalize each column (one column per stream) to unit power
    norm = tf.sqrt(tf.reduce_sum(tf.abs(g)**2, axis=-2, keepdims=True))
    g = g/tf.cast(norm, g.dtype)

    # Precode: treat x as a column vector [..., K, 1], then drop the unit dim
    x_precoded = tf.squeeze(tf.matmul(g, tf.expand_dims(x, -1)), -1)

    if return_precoding_matrix:
        return (x_precoded, g)
    return x_precoded


def grid_of_beams_dft_ula(num_ant,
                          oversmpl=1):
    # pylint: disable=line-too-long
    r""" Computes the Discrete Fourier Transform (DFT) Grid of Beam (GoB)
    coefficients for a uniform linear array (ULA)

    The coefficient applied to antenna :math:`n` for beam :math:`m` is expressed
    as:

    .. math::

        c_n^m = \frac{1}{\sqrt{N}} e^{j\frac{2\pi n m}{N O}}, \quad n=0,\dots,N-1, \ m=0,\dots,NO-1

    where :math:`N` is the number of antennas ``num_ant`` and :math:`O` is the
    oversampling factor ``oversmpl``. The factor :math:`1/\sqrt{N}` normalizes
    every beam to unit norm.

    Note that, for half-wavelength antenna spacing, the main lobe of beam
    :math:`m` points in the azimuth direction
    :math:`\theta = \mathrm{arc sin} \left( 2\frac{m}{NO} \right)` if
    :math:`m\le NO/2` and
    :math:`\theta = \mathrm{arc sin} \left( 2\frac{m-NO}{NO} \right)` if
    :math:`m\ge NO/2`, where :math:`\theta=0` defines the perpendicular to the
    antenna array.

    Input
    ------
    num_ant : int
        Number of antennas

    oversmpl : int
        Oversampling factor

    Output
    -------
    gob : [num_ant x oversmpl, num_ant], tf.complex
        The :math:`m`-th row contains the `num_ant` antenna coefficients for
        the :math:`m`-th DFT beam
    """
    oversmpl = int(oversmpl)

    # Beam indices as a column: [num_ant * oversmpl, 1]
    beam_ind = tf.range(num_ant * oversmpl, dtype=tf.float32)[:, tf.newaxis]

    # Antenna indices as a row: [1, num_ant]
    antenna_ind = tf.range(num_ant, dtype=tf.float32)[tf.newaxis, :]

    # Phase of every (beam, antenna) pair via broadcasting:
    # [num_ant * oversmpl, num_ant]
    phases = 2 * PI * beam_ind * antenna_ind / (num_ant * oversmpl)

    # Combine real and imaginary part and normalize each beam to unit norm
    gob = tf.complex(tf.cos(phases), tf.sin(phases)) / math.sqrt(num_ant)
    return gob


def grid_of_beams_dft(num_ant_v,
                      num_ant_h,
                      oversmpl_v=1,
                      oversmpl_h=1):
    # pylint: disable=line-too-long
    r""" Computes the Discrete Fourier Transform (DFT) Grid of Beam (GoB)
    coefficients for a uniform rectangular array (URA)

    GoB indices are arranged over a 2D grid indexed by :math:`(m_v,m_h)`.
    The coefficient of the beam with index :math:`(m_v,m_h)` applied to the
    antenna located at row :math:`n_v` and column :math:`n_h` of the rectangular
    array is expressed as:

    .. math::

        c_{n_v,n_h}^{m_v,m_h} = \frac{1}{\sqrt{N_v N_h}} e^{j\frac{2\pi n_v m_v}{N_v O_v}} e^{j\frac{2\pi n_h m_h}{N_h O_h}}

    where :math:`n_v=0,\dots,N_v-1`, :math:`n_h=0,\dots,N_h-1`,
    :math:`m_v=0,\dots,N_v O_v-1`, :math:`m_h=0,\dots,N_h O_h-1`,
    :math:`N_v,N_h` are the numbers of antenna rows and columns
    ``num_ant_v``, ``num_ant_h``, and :math:`O_v,O_h` are the oversampling
    factors ``oversmpl_v``, ``oversmpl_h`` in the vertical and
    horizontal direction, respectively.

    The antenna at :math:`(n_v,n_h)` is mapped to the entry
    :math:`n_h N_v + n_v` of the flattened (column-wise) precoding vector, so
    the vector :math:`c^{m_v,m_h}` can be written more concisely as:

    .. math::

        c^{m_v,m_h} = c^{m_h} \otimes c^{m_v}

    where :math:`\otimes` denotes the Kronecker product and
    :math:`c^{m_v},c^{m_h}` are the ULA DFT beams computed as in
    :func:`~sionna.mimo.grid_of_beams_dft_ula` .

    Such a DFT GoB is, e.g., defined in Section 5.2.2.2.1 [3GPP38214]_.

    Input
    ------
    num_ant_v : int
        Number of antenna rows (i.e., in vertical direction) of the rectangular
        array

    num_ant_h : int
        Number of antenna columns (i.e., in horizontal direction) of the
        rectangular array.

    oversmpl_v : int
        Oversampling factor in vertical direction

    oversmpl_h : int
        Oversampling factor in horizontal direction

    Output
    -------
    gob : [num_ant_v x oversmpl_v, num_ant_h x oversmpl_h, num_ant_v x num_ant_h], tf.complex
        The elements :math:`[m_v,m_h,:]` contain the antenna coefficients of the
        DFT beam with index pair :math:`(m_v,m_h)`.
    """
    # Vertical ULA beams: [num_ant_v * oversmpl_v, 1, num_ant_v, 1]
    gob_v = grid_of_beams_dft_ula(num_ant_v, oversmpl=oversmpl_v)
    gob_v = gob_v[:, tf.newaxis, :, tf.newaxis]

    # Horizontal ULA beams: [1, num_ant_h * oversmpl_h, 1, num_ant_h]
    gob_h = grid_of_beams_dft_ula(num_ant_h, oversmpl=oversmpl_h)
    gob_h = gob_h[tf.newaxis, :, tf.newaxis, :]

    # Outer product of every vertical/horizontal beam pair via broadcasting:
    # [num_ant_v * oversmpl_v, num_ant_h * oversmpl_h, num_ant_v, num_ant_h]
    coef_vh = tf.math.multiply(gob_h, gob_v)

    # Flatten the antenna dims column-wise into 1-dimensional precoding vectors:
    # [num_ant_v * oversmpl_v, num_ant_h * oversmpl_h, num_ant_v x num_ant_h]
    coef_vh = flatten_precoding_mat(coef_vh)
    return coef_vh


def flatten_precoding_mat(precoding_mat, by_column=True):
    # pylint: disable=line-too-long
    r"""Flattens a [..., num_ant_v, num_ant_h] precoding matrix associated with
    a rectangular array by producing a [..., num_ant_v x num_ant_h] precoding vector.

    Input
    ------
    precoding_mat : [..., num_antennas_vertical, num_antennas_horizontal], tf.complex
        Precoding matrix (rank 2 or higher). The element :math:`(i,j)`
        contains the precoding coefficient of the antenna element located at
        row :math:`i` and column :math:`j` of a rectangular antenna array.

    by_column : bool
        If `True`, then flattening occurs on a per-column basis, i.e., the
        second column is appended to the first, the third to the second, and
        so on; element :math:`(i,j)` lands at index
        :math:`j \cdot \text{num\_antennas\_vertical} + i`. Else, flattening
        is performed on a per-row basis and element :math:`(i,j)` lands at
        index :math:`i \cdot \text{num\_antennas\_horizontal} + j`.

    Output
    -------
    : [..., num_antennas_vertical x num_antennas_horizontal], tf.complex
        Flattened precoding vector
    """
    # Accept numpy arrays / lists as well as tensors
    precoding_mat = tf.convert_to_tensor(precoding_mat)

    # Column-wise flattening == row-major flattening of the transpose
    if by_column:
        precoding_mat = tf.linalg.matrix_transpose(precoding_mat)

    # Merge the last two dimensions; leading (batch) dimensions are kept.
    static_shape = precoding_mat.shape
    if static_shape.is_fully_defined():
        dims = static_shape.as_list()
        new_shape = dims[:-2] + [dims[-2] * dims[-1]]
    else:
        # Fall back to dynamic shapes when some dims are unknown (graph mode)
        dyn_shape = tf.shape(precoding_mat)
        new_shape = tf.concat(
            [dyn_shape[:-2], [dyn_shape[-2] * dyn_shape[-1]]], axis=0)
    precoding_vec = tf.reshape(precoding_mat, new_shape)
    return precoding_vec


def normalize_precoding_power(precoding_vec, dtype=None, tx_power_list=None):
    # pylint: disable=line-too-long
    r""" Normalizes the beam coefficient power to 1 by default, or to
    ``tx_power_list`` if provided as input.

    The power of a row is its squared Euclidean norm; each row is therefore
    scaled to have norm :math:`\sqrt{P_i}`, where :math:`P_i` is the
    :math:`i`-th entry of ``tx_power_list`` (1 by default). Rows with zero
    norm yield NaN entries.

    Input
    ------
    precoding_vec : [N,M] or [M], tf.complex
        Each row contains a set of antenna coefficients whose power is to be
        normalized. A 1D input is treated as a single row.

    dtype : dtype
        dtype of the output (should be complex). Defaults to None, in which
        case the dtype of ``precoding_vec`` is used.

    tx_power_list : [N], float
        The :math:`i`-th element defines the power of the :math:`i`-th precoding vector.

    Output
    -------
    : [N,M] tf.complex
        Normalized antenna coefficients.
    """
    precoding_vec = tf.convert_to_tensor(precoding_vec)
    if dtype is None:
        dtype = precoding_vec.dtype
    # Cast up front so that all subsequent arithmetic shares one dtype
    precoding_vec = tf.cast(precoding_vec, dtype)

    if len(precoding_vec.shape) == 1:
        precoding_vec = precoding_vec[tf.newaxis, :]

    # Euclidean norm of each row as an [N, 1] column for broadcasting
    precoding_vec_norm = tf.cast(
        tf.norm(precoding_vec, axis=1, keepdims=True), dtype)

    if tx_power_list is None:
        # By default, power is normalized to 1 (broadcast over all rows)
        amplitude = tf.ones([1, 1], dtype=dtype)
    else:
        # Power is the squared norm, hence scale by the square root
        tx_power = tf.constant(tx_power_list, dtype=dtype)[:, tf.newaxis]
        amplitude = tf.sqrt(tx_power)

    # Normalize each row to unit norm, then scale it to the target power
    precoding_vec = tf.math.multiply(
        tf.math.divide(precoding_vec, precoding_vec_norm), amplitude)

    return precoding_vec