spatial_correlation.py (6908B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Various classes for spatially correlated flat-fading channels."""

from abc import ABC, abstractmethod
import tensorflow as tf
from tensorflow.experimental.numpy import swapaxes
from sionna.utils import expand_to_rank, matrix_sqrt


class SpatialCorrelation(ABC):
    # pylint: disable=line-too-long
    r"""Abstract class that defines an interface for spatial correlation functions.

    The :class:`~sionna.channel.FlatFadingChannel` model can be configured with a
    spatial correlation model.

    Input
    -----
    h : tf.complex
        Tensor of arbitrary shape containing spatially uncorrelated
        channel coefficients

    Output
    ------
    h_corr : tf.complex
        Tensor of the same shape and dtype as ``h`` containing the spatially
        correlated channel coefficients.
    """
    @abstractmethod
    def __call__(self, h, *args, **kwargs):
        # ``NotImplemented`` is a sentinel meant for binary-operator dunders;
        # an unimplemented interface method raises instead, so that a stray
        # ``super().__call__(h)`` in a subclass fails loudly.
        raise NotImplementedError


class KroneckerModel(SpatialCorrelation):
    # pylint: disable=line-too-long
    r"""Kronecker model for spatial correlation.

    Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`,
    :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}`, and
    :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`, this function
    will generate the following output:

    .. math::

        \mathbf{H}_\text{corr} = \mathbf{R}^{\frac12}_\text{rx} \mathbf{H} \mathbf{R}^{\frac12}_\text{tx}

    Note that :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}` and :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`
    must be positive semi-definite, such as the ones generated by
    :meth:`~sionna.channel.exp_corr_mat`.

    Parameters
    ----------
    r_tx : [..., K, K], tf.complex
        Tensor containing the transmit correlation matrices. If
        the rank of ``r_tx`` is smaller than that of the input ``h``,
        it will be broadcast. If `None` (default), no transmit
        correlation is applied.

    r_rx : [..., M, M], tf.complex
        Tensor containing the receive correlation matrices. If
        the rank of ``r_rx`` is smaller than that of the input ``h``,
        it will be broadcast. If `None` (default), no receive
        correlation is applied.

    Input
    -----
    h : [..., M, K], tf.complex
        Tensor containing spatially uncorrelated
        channel coefficients.

    Output
    ------
    h_corr : [..., M, K], tf.complex
        Tensor containing the spatially
        correlated channel coefficients.
    """
    def __init__(self, r_tx=None, r_rx=None):
        super().__init__()
        # Assign through the property setters so the square roots are cached.
        self.r_tx = r_tx
        self.r_rx = r_rx

    @property
    def r_tx(self):
        r"""Tensor containing the transmit correlation matrices.

        Note
        ----
        If you want to set this property in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.Config.xla_compat=True``.
        See :py:attr:`~sionna.Config.xla_compat`.
        """
        return self._r_tx

    @r_tx.setter
    def r_tx(self, value):
        self._r_tx = value
        # Cache the matrix square root once so __call__ does not recompute it.
        if self._r_tx is not None:
            self._r_tx_sqrt = matrix_sqrt(value)
        else:
            self._r_tx_sqrt = None

    @property
    def r_rx(self):
        r"""Tensor containing the receive correlation matrices.

        Note
        ----
        If you want to set this property in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.Config.xla_compat=True``.
        See :py:attr:`~sionna.Config.xla_compat`.
        """
        return self._r_rx

    @r_rx.setter
    def r_rx(self, value):
        self._r_rx = value
        # Cache the matrix square root once so __call__ does not recompute it.
        if self._r_rx is not None:
            self._r_rx_sqrt = matrix_sqrt(value)
        else:
            self._r_rx_sqrt = None

    def __call__(self, h):
        if self._r_tx_sqrt is not None:
            # Pad leading dimensions (presumably what expand_to_rank with
            # axis=0 does) so tf.matmul broadcasts over the batch dims of h.
            r_tx_sqrt = expand_to_rank(self._r_tx_sqrt, tf.rank(h), 0)
            # Right-multiply by the adjoint of R_tx^{1/2}. Assuming matrix_sqrt
            # returns the Hermitian (principal) root of a PSD matrix, this
            # equals R_tx^{1/2}, matching the documented formula.
            h = tf.matmul(h, r_tx_sqrt, adjoint_b=True)

        if self._r_rx_sqrt is not None:
            r_rx_sqrt = expand_to_rank(self._r_rx_sqrt, tf.rank(h), 0)
            # Left-multiply by R_rx^{1/2}.
            h = tf.matmul(r_rx_sqrt, h)

        return h


class PerColumnModel(SpatialCorrelation):
    # pylint: disable=line-too-long
    r"""Per-column model for spatial correlation.

    Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`
    and correlation matrices :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}, k=1,\dots,K`,
    this function will generate the output :math:`\mathbf{H}_\text{corr}\in\mathbb{C}^{M\times K}`,
    with columns

    .. math::

        \mathbf{h}^\text{corr}_k = \mathbf{R}^{\frac12}_k \mathbf{h}_k,\quad k=1, \dots, K

    where :math:`\mathbf{h}_k` is the kth column of :math:`\mathbf{H}`.
    Note that all :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}` must
    be positive semi-definite, such as the ones generated
    by :meth:`~sionna.channel.one_ring_corr_mat`.

    This model is typically used to simulate a MIMO channel between multiple
    single-antenna users and a base station with multiple antennas.
    The resulting SIMO channel for each user has a different spatial correlation.

    Parameters
    ----------
    r_rx : [..., M, M], tf.complex
        Tensor containing the receive correlation matrices. If
        the rank of ``r_rx`` is smaller than that of the input ``h``,
        it will be broadcast. For a typical use of this model, ``r_rx``
        has shape [..., K, M, M], i.e., a different correlation matrix for each
        column of ``h``. If `None`, ``h`` is returned unchanged.

    Input
    -----
    h : [..., M, K], tf.complex
        Tensor containing spatially uncorrelated
        channel coefficients.

    Output
    ------
    h_corr : [..., M, K], tf.complex
        Tensor containing the spatially
        correlated channel coefficients.
    """
    def __init__(self, r_rx):
        super().__init__()
        # Assign through the property setter so the square root is cached.
        self.r_rx = r_rx

    @property
    def r_rx(self):
        """Tensor containing the receive correlation matrices.

        Note
        ----
        If you want to set this property in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.Config.xla_compat=True``.
        See :py:attr:`~sionna.Config.xla_compat`.
        """

        return self._r_rx

    @r_rx.setter
    def r_rx(self, value):
        self._r_rx = value
        if self._r_rx is not None:
            self._r_rx_sqrt = matrix_sqrt(value)
        else:
            # Drop any previously cached root so it is not kept alive stale,
            # consistent with KroneckerModel's setters.
            self._r_rx_sqrt = None

    def __call__(self, h):
        if self._r_rx is not None:
            # [..., M, K] -> [..., K, M]: one row per column h_k.
            h = swapaxes(h, -2, -1)
            # [..., K, M] -> [..., K, M, 1]: turn each h_k into a column vector
            # so it can be multiplied by its own R_k^{1/2} of shape [M, M].
            h = tf.expand_dims(h, -1)
            r_rx_sqrt = expand_to_rank(self._r_rx_sqrt, tf.rank(h), 0)
            h = tf.matmul(r_rx_sqrt, h)
            # Undo the reshaping: [..., K, M, 1] -> [..., K, M] -> [..., M, K].
            h = tf.squeeze(h, -1)
            h = swapaxes(h, -2, -1)

        return h