equalization.py (11422B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes and functions related to MIMO channel equalization"""

import tensorflow as tf
from sionna.utils import expand_to_rank, matrix_inv, matrix_pinv
from sionna.mimo.utils import whiten_channel


def lmmse_equalizer(y, h, s, whiten_interference=True):
    # pylint: disable=line-too-long
    r"""MIMO LMMSE Equalizer

    This function implements LMMSE equalization for a MIMO link, assuming the
    following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Lemma B.19) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}\mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of

    .. math::

        \mathop{\text{diag}}\left(\mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]\right)
        = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H} \right)^{-1} - \mathbf{I}.

    Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}`
    is important for the :class:`~sionna.mapping.Demapper` although it does
    not change the signal-to-noise ratio.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,K], tf.complex
        2+D tensor containing the channel matrices.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    whiten_interference : bool
        If `True` (default), the interference is first whitened before equalization.
        In this case, an alternative expression for the receive filter is used that
        can be numerically more stable. Defaults to `True`.

    Output
    ------
    x_hat : [...,K], tf.complex
        1+D tensor representing the estimated symbol vectors.

    no_eff : tf.float
        Tensor of the same shape as ``x_hat`` containing the effective noise
        variance estimates.

    Note
    ----
    If you want to use this function in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    # We assume the model:
    # y = Hx + n, where E[nn']=S.
    # E[x]=E[n]=0
    #
    # The LMMSE estimate of x is given as:
    # x_hat = diag(GH)^(-1)Gy
    # with G=H'(HH'+S)^(-1).
    #
    # This leads us to the per-symbol model:
    #
    # x_hat_k = x_k + e_k
    #
    # The elements of the residual noise vector e have variance:
    # diag(E[ee']) = diag(GH)^(-1) - I
    if not whiten_interference:
        # Compute G directly from the unwhitened model
        g = tf.matmul(h, h, adjoint_b=True) + s
        g = tf.matmul(h, matrix_inv(g), adjoint_a=True)

    else:
        # Whiten channel: after whitening, the noise covariance is the
        # identity, which allows the numerically more stable expression
        # G = (H'H + I)^(-1) H' below.
        y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=unbalanced-tuple-unpacking

        # Compute G
        i = expand_to_rank(tf.eye(h.shape[-1], dtype=s.dtype), tf.rank(s), 0)
        g = tf.matmul(h, h, adjoint_a=True) + i
        g = tf.matmul(matrix_inv(g), h, adjoint_b=True)

    # Compute Gy
    y = tf.expand_dims(y, -1)
    gy = tf.squeeze(tf.matmul(g, y), axis=-1)

    # Compute GH
    gh = tf.matmul(g, h)

    # Compute diag(GH)
    d = tf.linalg.diag_part(gh)

    # Compute x_hat (per-symbol rescaling so that x_hat_k = x_k + e_k)
    x_hat = gy/d

    # Compute residual error variance: diag(GH)^(-1) - I
    one = tf.cast(1, dtype=d.dtype)
    no_eff = tf.math.real(one/d - one)

    return x_hat, no_eff

def zf_equalizer(y, h, s):
    # pylint: disable=line-too-long
    r"""MIMO ZF Equalizer

    This function implements zero-forcing (ZF) equalization for a MIMO link, assuming the
    following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Eq. 4.10) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of the matrix

    .. math::

        \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
        = \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,K], tf.complex
        2+D tensor containing the channel matrices.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    x_hat : [...,K], tf.complex
        1+D tensor representing the estimated symbol vectors.

    no_eff : tf.float
        Tensor of the same shape as ``x_hat`` containing the effective noise
        variance estimates.

    Note
    ----
    If you want to use this function in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    # We assume the model:
    # y = Hx + n, where E[nn']=S.
    # E[x]=E[n]=0
    #
    # The ZF estimate of x is given as:
    # x_hat = Gy
    # with G=(H'H)^(-1)H'.
    #
    # This leads us to the per-symbol model:
    #
    # x_hat_k = x_k + e_k
    #
    # The elements of the residual noise vector e have variance:
    # E[ee'] = GSG'

    # Compute G (Moore-Penrose pseudo-inverse of H)
    g = matrix_pinv(h)

    # Compute x_hat
    y = tf.expand_dims(y, -1)
    x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)

    # Compute residual error variance: diag(GSG')
    gsg = tf.matmul(tf.matmul(g, s), g, adjoint_b=True)
    no_eff = tf.math.real(tf.linalg.diag_part(gsg))

    return x_hat, no_eff

def mf_equalizer(y, h, s):
    # pylint: disable=line-too-long
    r"""MIMO MF Equalizer

    This function implements matched filter (MF) equalization for a
    MIMO link, assuming the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Eq. 4.11) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of the matrix

    .. math::

        \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
        = \left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)\left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)^{\mathsf{H}} + \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.

    Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}`
    in the definition of :math:`\mathbf{G}`
    is important for the :class:`~sionna.mapping.Demapper` although it does
    not change the signal-to-noise ratio.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,K], tf.complex
        2+D tensor containing the channel matrices.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    x_hat : [...,K], tf.complex
        1+D tensor representing the estimated symbol vectors.

    no_eff : tf.float
        Tensor of the same shape as ``x_hat`` containing the effective noise
        variance estimates.
    """

    # We assume the model:
    # y = Hx + n, where E[nn']=S.
    # E[x]=E[n]=0
    #
    # The MF estimate of x is given as:
    # x_hat = Gy
    # with G=diag(H'H)^-1 H'.
    #
    # This leads us to the per-symbol model:
    #
    # x_hat_k = x_k + e_k
    #
    # The elements of the residual noise vector e have variance:
    # E[ee'] = (I-GH)(I-GH)' + GSG'

    # Compute G = diag(H'H)^(-1) H'
    hth = tf.matmul(h, h, adjoint_a=True)
    d = tf.linalg.diag(tf.cast(1, h.dtype)/tf.linalg.diag_part(hth))
    g = tf.matmul(d, h, adjoint_b=True)

    # Compute x_hat
    y = tf.expand_dims(y, -1)
    x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)

    # Compute residual error variance: diag((I-GH)(I-GH)' + GSG')
    gsg = tf.matmul(tf.matmul(g, s), g, adjoint_b=True)
    gh = tf.matmul(g, h)
    i = expand_to_rank(tf.eye(gsg.shape[-2], dtype=gsg.dtype), tf.rank(gsg), 0)

    no_eff = tf.abs(tf.linalg.diag_part(tf.matmul(i-gh, i-gh, adjoint_b=True) + gsg))
    return x_hat, no_eff