anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

equalization.py (11422B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Classes and functions related to MIMO channel equalization"""
      6 
      7 import tensorflow as tf
      8 from sionna.utils import expand_to_rank, matrix_inv, matrix_pinv
      9 from sionna.mimo.utils import whiten_channel
     10 
     11 
     12 def lmmse_equalizer(y, h, s, whiten_interference=True):
     13     # pylint: disable=line-too-long
     14     r"""MIMO LMMSE Equalizer
     15 
     16     This function implements LMMSE equalization for a MIMO link, assuming the
     17     following model:
     18 
     19     .. math::
     20 
     21         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
     22 
     23     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
     24     :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
     25     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
     26     and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
     27     It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
     28     :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
     29     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.
     30 
     31     The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
     32     (Lemma B.19) [BHS2017]_ :
     33 
     34     .. math::
     35 
     36         \hat{\mathbf{x}} = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}\mathbf{G}\mathbf{y}
     37 
     38     where
     39 
     40     .. math::
     41 
     42         \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}.
     43 
     44     This leads to the post-equalized per-symbol model:
     45 
     46     .. math::
     47 
     48         \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1
     49 
     50     where the variances :math:`\sigma^2_k` of the effective residual noise
     51     terms :math:`e_k` are given by the diagonal elements of
     52 
     53     .. math::
     54 
     55         \mathop{\text{diag}}\left(\mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]\right)
     56         = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H} \right)^{-1} - \mathbf{I}.
     57 
     58     Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}`
     59     is important for the :class:`~sionna.mapping.Demapper` although it does
     60     not change the signal-to-noise ratio.
     61 
     62     The function returns :math:`\hat{\mathbf{x}}` and
     63     :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.
     64 
     65     Input
     66     -----
     67     y : [...,M], tf.complex
     68         1+D tensor containing the received signals.
     69 
     70     h : [...,M,K], tf.complex
     71         2+D tensor containing the channel matrices.
     72 
     73     s : [...,M,M], tf.complex
     74         2+D tensor containing the noise covariance matrices.
     75 
     76     whiten_interference : bool
     77         If `True` (default), the interference is first whitened before equalization.
     78         In this case, an alternative expression for the receive filter is used that
     79         can be numerically more stable. Defaults to `True`.
     80 
     81     Output
     82     ------
     83     x_hat : [...,K], tf.complex
     84         1+D tensor representing the estimated symbol vectors.
     85 
     86     no_eff : tf.float
     87         Tensor of the same shape as ``x_hat`` containing the effective noise
     88         variance estimates.
     89 
     90     Note
     91     ----
     92     If you want to use this function in Graph mode with XLA, i.e., within
     93     a function that is decorated with ``@tf.function(jit_compile=True)``,
     94     you must set ``sionna.Config.xla_compat=true``.
     95     See :py:attr:`~sionna.Config.xla_compat`.
     96     """
     97 
     98     # We assume the model:
     99     # y = Hx + n, where E[nn']=S.
    100     # E[x]=E[n]=0
    101     #
    102     # The LMMSE estimate of x is given as:
    103     # x_hat = diag(GH)^(-1)Gy
    104     # with G=H'(HH'+S)^(-1).
    105     #
    106     # This leads us to the per-symbol model;
    107     #
    108     # x_hat_k = x_k + e_k
    109     #
    110     # The elements of the residual noise vector e have variance:
    111     # diag(E[ee']) = diag(GH)^(-1) - I
    112     if not whiten_interference:
    113         # Compute G
    114         g = tf.matmul(h, h, adjoint_b=True) + s
    115         g = tf.matmul(h, matrix_inv(g), adjoint_a=True)
    116 
    117     else:
    118         # Whiten channel
    119         y, h  = whiten_channel(y, h, s, return_s=False) # pylint: disable=unbalanced-tuple-unpacking
    120 
    121         # Compute G
    122         i = expand_to_rank(tf.eye(h.shape[-1], dtype=s.dtype), tf.rank(s), 0)
    123         g = tf.matmul(h, h, adjoint_a=True) + i
    124         g = tf.matmul(matrix_inv(g), h, adjoint_b=True)
    125 
    126     # Compute Gy
    127     y = tf.expand_dims(y, -1)
    128     gy = tf.squeeze(tf.matmul(g, y), axis=-1)
    129 
    130     # Compute GH
    131     gh = tf.matmul(g, h)
    132 
    133     # Compute diag(GH)
    134     d = tf.linalg.diag_part(gh)
    135 
    136     # Compute x_hat
    137     x_hat = gy/d
    138 
    139     # Compute residual error variance
    140     one = tf.cast(1, dtype=d.dtype)
    141     no_eff = tf.math.real(one/d - one)
    142 
    143     return x_hat, no_eff
    144 
    145 def zf_equalizer(y, h, s):
    146     # pylint: disable=line-too-long
    147     r"""MIMO ZF Equalizer
    148 
    149     This function implements zero-forcing (ZF) equalization for a MIMO link, assuming the
    150     following model:
    151 
    152     .. math::
    153 
    154         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    155 
    156     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    157     :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    158     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    159     and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    160     It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    161     :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    162     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.
    163 
    164     The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    165     (Eq. 4.10) [BHS2017]_ :
    166 
    167     .. math::
    168 
    169         \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}
    170 
    171     where
    172 
    173     .. math::
    174 
    175         \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.
    176 
    177     This leads to the post-equalized per-symbol model:
    178 
    179     .. math::
    180 
    181         \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1
    182 
    183     where the variances :math:`\sigma^2_k` of the effective residual noise
    184     terms :math:`e_k` are given by the diagonal elements of the matrix
    185 
    186     .. math::
    187 
    188         \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
    189         = \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.
    190 
    191     The function returns :math:`\hat{\mathbf{x}}` and
    192     :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.
    193 
    194     Input
    195     -----
    196     y : [...,M], tf.complex
    197         1+D tensor containing the received signals.
    198 
    199     h : [...,M,K], tf.complex
    200         2+D tensor containing the channel matrices.
    201 
    202     s : [...,M,M], tf.complex
    203         2+D tensor containing the noise covariance matrices.
    204 
    205     Output
    206     ------
    207     x_hat : [...,K], tf.complex
    208         1+D tensor representing the estimated symbol vectors.
    209 
    210     no_eff : tf.float
    211         Tensor of the same shape as ``x_hat`` containing the effective noise
    212         variance estimates.
    213 
    214     Note
    215     ----
    216     If you want to use this function in Graph mode with XLA, i.e., within
    217     a function that is decorated with ``@tf.function(jit_compile=True)``,
    218     you must set ``sionna.Config.xla_compat=true``.
    219     See :py:attr:`~sionna.Config.xla_compat`.
    220     """
    221 
    222     # We assume the model:
    223     # y = Hx + n, where E[nn']=S.
    224     # E[x]=E[n]=0
    225     #
    226     # The ZF estimate of x is given as:
    227     # x_hat = Gy
    228     # with G=(H'H')^(-1)H'.
    229     #
    230     # This leads us to the per-symbol model;
    231     #
    232     # x_hat_k = x_k + e_k
    233     #
    234     # The elements of the residual noise vector e have variance:
    235     # E[ee'] = GSG'
    236 
    237     # Compute G
    238     g = matrix_pinv(h)
    239 
    240     # Compute x_hat
    241     y = tf.expand_dims(y, -1)
    242     x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)
    243 
    244     # Compute residual error variance
    245     gsg = tf.matmul(tf.matmul(g, s), g, adjoint_b=True)
    246     no_eff = tf.math.real(tf.linalg.diag_part(gsg))
    247 
    248     return x_hat, no_eff
    249 
    250 def mf_equalizer(y, h, s):
    251     # pylint: disable=line-too-long
    252     r"""MIMO MF Equalizer
    253 
    254     This function implements matched filter (MF) equalization for a
    255     MIMO link, assuming the following model:
    256 
    257     .. math::
    258 
    259         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    260 
    261     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    262     :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    263     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    264     and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    265     It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    266     :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    267     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.
    268 
    269     The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    270     (Eq. 4.11) [BHS2017]_ :
    271 
    272     .. math::
    273 
    274         \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}
    275 
    276     where
    277 
    278     .. math::
    279 
    280         \mathbf{G} = \mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.
    281 
    282     This leads to the post-equalized per-symbol model:
    283 
    284     .. math::
    285 
    286         \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1
    287 
    288     where the variances :math:`\sigma^2_k` of the effective residual noise
    289     terms :math:`e_k` are given by the diagonal elements of the matrix
    290 
    291     .. math::
    292 
    293         \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
    294         = \left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)\left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)^{\mathsf{H}} + \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.
    295 
    296     Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}`
    297     in the definition of :math:`\mathbf{G}`
    298     is important for the :class:`~sionna.mapping.Demapper` although it does
    299     not change the signal-to-noise ratio.
    300 
    301     The function returns :math:`\hat{\mathbf{x}}` and
    302     :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.
    303 
    304     Input
    305     -----
    306     y : [...,M], tf.complex
    307         1+D tensor containing the received signals.
    308 
    309     h : [...,M,K], tf.complex
    310         2+D tensor containing the channel matrices.
    311 
    312     s : [...,M,M], tf.complex
    313         2+D tensor containing the noise covariance matrices.
    314 
    315     Output
    316     ------
    317     x_hat : [...,K], tf.complex
    318         1+D tensor representing the estimated symbol vectors.
    319 
    320     no_eff : tf.float
    321         Tensor of the same shape as ``x_hat`` containing the effective noise
    322         variance estimates.
    323     """
    324 
    325     # We assume the model:
    326     # y = Hx + n, where E[nn']=S.
    327     # E[x]=E[n]=0
    328     #
    329     # The MF estimate of x is given as:
    330     # x_hat = Gy
    331     # with G=diag(H'H)^-1 H'.
    332     #
    333     # This leads us to the per-symbol model;
    334     #
    335     # x_hat_k = x_k + e_k
    336     #
    337     # The elements of the residual noise vector e have variance:
    338     # E[ee'] = (I-GH)(I-GH)' + GSG'
    339 
    340     # Compute G
    341     hth = tf.matmul(h, h, adjoint_a=True)
    342     d = tf.linalg.diag(tf.cast(1, h.dtype)/tf.linalg.diag_part(hth))
    343     g = tf.matmul(d, h, adjoint_b=True)
    344 
    345     # Compute x_hat
    346     y = tf.expand_dims(y, -1)
    347     x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)
    348 
    349     # Compute residual error variance
    350     gsg = tf.matmul(tf.matmul(g, s), g, adjoint_b=True)
    351     gh = tf.matmul(g, h)
    352     i = expand_to_rank(tf.eye(gsg.shape[-2], dtype=gsg.dtype), tf.rank(gsg), 0)
    353 
    354     no_eff = tf.abs(tf.linalg.diag_part(tf.matmul(i-gh, i-gh, adjoint_b=True) + gsg))
    355     return x_hat, no_eff