anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

spatial_correlation.py (6908B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Various classes for spatially correlated flat-fading channels."""
      6 
      7 from abc import ABC, abstractmethod
      8 import tensorflow as tf
      9 from tensorflow.experimental.numpy import swapaxes
     10 from sionna.utils import expand_to_rank, matrix_sqrt
     11 
     12 class SpatialCorrelation(ABC):
     13     # pylint: disable=line-too-long
     14     r"""Abstract class that defines an interface for spatial correlation functions.
     15 
     16     The :class:`~sionna.channel.FlatFadingChannel` model can be configured with a
     17     spatial correlation model.
     18 
     19     Input
     20     -----
     21     h : tf.complex
     22         Tensor of arbitrary shape containing spatially uncorrelated
     23         channel coefficients
     24 
     25     Output
     26     ------
     27     h_corr : tf.complex
     28         Tensor of the same shape and dtype as ``h`` containing the spatially
     29         correlated channel coefficients.
     30     """
     31     @abstractmethod
     32     def __call__(self, h, *args, **kwargs):
     33         return NotImplemented
     34 
     35 class KroneckerModel(SpatialCorrelation):
     36     # pylint: disable=line-too-long
     37     r"""Kronecker model for spatial correlation.
     38 
     39     Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`,
     40     :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}`, and
     41     :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`, this function
     42     will generate the following output:
     43 
     44     .. math::
     45 
     46         \mathbf{H}_\text{corr} = \mathbf{R}^{\frac12}_\text{rx} \mathbf{H} \mathbf{R}^{\frac12}_\text{tx}
     47 
     48     Note that :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}` and :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`
     49     must be positive semi-definite, such as the ones generated by
     50     :meth:`~sionna.channel.exp_corr_mat`.
     51 
     52     Parameters
     53     ----------
     54     r_tx : [..., K, K], tf.complex
     55         Tensor containing the transmit correlation matrices. If
     56         the rank of ``r_tx`` is smaller than that of the input ``h``,
     57         it will be broadcast.
     58 
     59     r_rx : [..., M, M], tf.complex
     60         Tensor containing the receive correlation matrices. If
     61         the rank of ``r_rx`` is smaller than that of the input ``h``,
     62         it will be broadcast.
     63 
     64     Input
     65     -----
     66     h : [..., M, K], tf.complex
     67         Tensor containing spatially uncorrelated
     68         channel coeffficients.
     69 
     70     Output
     71     ------
     72     h_corr : [..., M, K], tf.complex
     73         Tensor containing the spatially
     74         correlated channel coefficients.
     75     """
     76     def __init__(self, r_tx=None, r_rx=None):
     77         super().__init__()
     78         self.r_tx = r_tx
     79         self.r_rx = r_rx
     80 
     81     @property
     82     def r_tx(self):
     83         r"""Tensor containing the transmit correlation matrices.
     84 
     85         Note
     86         ----
     87         If you want to set this property in Graph mode with XLA, i.e., within
     88         a function that is decorated with ``@tf.function(jit_compile=True)``,
     89         you must set ``sionna.Config.xla_compat=true``.
     90         See :py:attr:`~sionna.Config.xla_compat`.
     91         """
     92         return self._r_tx
     93 
     94     @r_tx.setter
     95     def r_tx(self, value):
     96         self._r_tx = value
     97         if self._r_tx is not None:
     98             self._r_tx_sqrt = matrix_sqrt(value)
     99         else:
    100             self._r_tx_sqrt = None
    101 
    102     @property
    103     def r_rx(self):
    104         r"""Tensor containing the receive correlation matrices.
    105 
    106         Note
    107         ----
    108         If you want to set this property in Graph mode with XLA, i.e., within
    109         a function that is decorated with ``@tf.function(jit_compile=True)``,
    110         you must set ``sionna.Config.xla_compat=true``.
    111         See :py:attr:`~sionna.Config.xla_compat`.
    112         """
    113         return self._r_rx
    114 
    115     @r_rx.setter
    116     def r_rx(self, value):
    117         self._r_rx = value
    118         if self._r_rx is not None:
    119             self._r_rx_sqrt = matrix_sqrt(value)
    120         else:
    121             self._r_rx_sqrt = None
    122 
    123     def __call__(self, h):
    124         if self._r_tx_sqrt is not None:
    125             r_tx_sqrt = expand_to_rank(self._r_tx_sqrt, tf.rank(h), 0)
    126             h = tf.matmul(h, r_tx_sqrt, adjoint_b=True)
    127 
    128         if self._r_rx_sqrt is not None:
    129             r_rx_sqrt = expand_to_rank(self._r_rx_sqrt, tf.rank(h), 0)
    130             h = tf.matmul(r_rx_sqrt, h)
    131 
    132         return h
    133 
    134 class PerColumnModel(SpatialCorrelation):
    135         # pylint: disable=line-too-long
    136     r"""Per-column model for spatial correlation.
    137 
    138     Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`
    139     and correlation matrices :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}, k=1,\dots,K`,
    140     this function will generate the output :math:`\mathbf{H}_\text{corr}\in\mathbb{C}^{M\times K}`,
    141     with columns
    142 
    143     .. math::
    144 
    145         \mathbf{h}^\text{corr}_k = \mathbf{R}^{\frac12}_k \mathbf{h}_k,\quad k=1, \dots, K
    146 
    147     where :math:`\mathbf{h}_k` is the kth column of :math:`\mathbf{H}`.
    148     Note that all :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}` must
    149     be positive semi-definite, such as the ones generated
    150     by :meth:`~sionna.channel.one_ring_corr_mat`.
    151 
    152     This model is typically used to simulate a MIMO channel between multiple
    153     single-antenna users and a base station with multiple antennas.
    154     The resulting SIMO channel for each user has a different spatial correlation.
    155 
    156     Parameters
    157     ----------
    158     r_rx : [..., M, M], tf.complex
    159         Tensor containing the receive correlation matrices. If
    160         the rank of ``r_rx`` is smaller than that of the input ``h``,
    161         it will be broadcast. For a typically use of this model, ``r_rx``
    162         has shape [..., K, M, M], i.e., a different correlation matrix for each
    163         column of ``h``.
    164 
    165     Input
    166     -----
    167     h : [..., M, K], tf.complex
    168         Tensor containing spatially uncorrelated
    169         channel coeffficients.
    170 
    171     Output
    172     ------
    173     h_corr : [..., M, K], tf.complex
    174         Tensor containing the spatially
    175         correlated channel coefficients.
    176     """
    177     def __init__(self, r_rx):
    178         super().__init__()
    179         self.r_rx = r_rx
    180 
    181     @property
    182     def r_rx(self):
    183         """Tensor containing the receive correlation matrices.
    184 
    185         Note
    186         ----
    187         If you want to set this property in Graph mode with XLA, i.e., within
    188         a function that is decorated with ``@tf.function(jit_compile=True)``,
    189         you must set ``sionna.Config.xla_compat=true``.
    190         See :py:attr:`~sionna.Config.xla_compat`.
    191         """
    192 
    193         return self._r_rx
    194 
    195     @r_rx.setter
    196     def r_rx(self, value):
    197         self._r_rx = value
    198         if self._r_rx is not None:
    199             self._r_rx_sqrt = matrix_sqrt(value)
    200 
    201     def __call__(self, h):
    202         if self._r_rx is not None:
    203             h = swapaxes(h, -2, -1)
    204             h = tf.expand_dims(h, -1)
    205             r_rx_sqrt = expand_to_rank(self._r_rx_sqrt, tf.rank(h), 0)
    206             h = tf.matmul(r_rx_sqrt, h)
    207             h = tf.squeeze(h, -1)
    208             h = swapaxes(h, -2, -1)
    209 
    210         return h