anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

channel_estimation.py (89915B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Functions related to OFDM channel estimation"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 import numpy as np
     10 from sionna.channel.tr38901 import models
     11 from sionna.utils import flatten_last_dims, expand_to_rank, matrix_inv
     12 from sionna.ofdm import ResourceGrid, RemoveNulledSubcarriers
     13 from sionna import PI, SPEED_OF_LIGHT
     14 from scipy.special import jv
     15 import itertools
     16 from abc import ABC, abstractmethod
     17 import json
     18 from importlib_resources import files
     19 
     20 class BaseChannelEstimator(ABC, Layer):
     21     # pylint: disable=line-too-long
     22     r"""BaseChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)
     23 
     24     Abstract layer for implementing an OFDM channel estimator.
     25 
     26     Any layer that implements an OFDM channel estimator must implement this
     27     class and its
     28     :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
     29     abstract method.
     30 
     31     This class extracts the pilots from the received resource grid ``y``, calls
     32     the :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
     33     method to estimate the channel for the pilot-carrying resource elements,
     34     and then interpolates the channel to compute channel estimates for the
     35     data-carrying resouce elements using the interpolation method specified by
     36     ``interpolation_type`` or the ``interpolator`` object.
     37 
     38     Parameters
     39     ----------
     40     resource_grid : ResourceGrid
     41         An instance of :class:`~sionna.ofdm.ResourceGrid`.
     42 
     43     interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
     44         The interpolation method to be used.
     45         It is ignored if ``interpolator`` is not `None`.
     46         Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn`")
     47         or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
     48         averaging across OFDM symbols (`"lin_time_avg"`).
     49         Defaults to "nn".
     50 
     51     interpolator : BaseChannelInterpolator
     52         An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
     53         such as :class:`~sionna.ofdm.LMMSEInterpolator`,
     54         or `None`. In the latter case, the interpolator specfied
     55         by ``interpolation_type`` is used.
     56         Otherwise, the ``interpolator`` is used and ``interpolation_type``
     57         is ignored.
     58         Defaults to `None`.
     59 
     60     dtype : tf.Dtype
     61         Datatype for internal calculations and the output dtype.
     62         Defaults to `tf.complex64`.
     63 
     64     Input
     65     -----
     66     (y, no) :
     67         Tuple:
     68 
     69     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
     70         Observed resource grid
     71 
     72     no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
     73         Variance of the AWGN
     74 
     75     Output
     76     ------
     77     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols,fft_size], tf.complex
     78         Channel estimates accross the entire resource grid for all
     79         transmitters and streams
     80 
     81     err_var : Same shape as ``h_hat``, tf.float
     82         Channel estimation error variance accross the entire resource grid
     83         for all transmitters and streams
     84     """
     85     def __init__(self, resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs):
     86         super().__init__(dtype=dtype, **kwargs)
     87 
     88         assert isinstance(resource_grid, ResourceGrid),\
     89             "You must provide a valid instance of ResourceGrid."
     90         self._pilot_pattern = resource_grid.pilot_pattern
     91         self._removed_nulled_scs = RemoveNulledSubcarriers(resource_grid)
     92 
     93         assert interpolation_type in ["nn","lin","lin_time_avg",None], \
     94             "Unsupported `interpolation_type`"
     95         self._interpolation_type = interpolation_type
     96 
     97         if interpolator is not None:
     98             assert isinstance(interpolator, BaseChannelInterpolator), \
     99         "`interpolator` must implement the BaseChannelInterpolator interface"
    100             self._interpol = interpolator
    101         elif self._interpolation_type == "nn":
    102             self._interpol = NearestNeighborInterpolator(self._pilot_pattern)
    103         elif self._interpolation_type == "lin":
    104             self._interpol = LinearInterpolator(self._pilot_pattern)
    105         elif self._interpolation_type == "lin_time_avg":
    106             self._interpol = LinearInterpolator(self._pilot_pattern,
    107                                                 time_avg=True)
    108 
    109         # Precompute indices to gather received pilot signals
    110         num_pilot_symbols = self._pilot_pattern.num_pilot_symbols
    111         mask = flatten_last_dims(self._pilot_pattern.mask)
    112         pilot_ind = tf.argsort(mask, axis=-1, direction="DESCENDING")
    113         self._pilot_ind = pilot_ind[...,:num_pilot_symbols]
    114 
    115     @abstractmethod
    116     def estimate_at_pilot_locations(self, y_pilots, no):
    117         """
    118         Estimates the channel for the pilot-carrying resource elements.
    119 
    120         This is an abstract method that must be implemented by a concrete
    121         OFDM channel estimator that implement this class.
    122 
    123         Input
    124         -----
    125         y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
    126             Observed signals for the pilot-carrying resource elements
    127 
    128         no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
    129             Variance of the AWGN
    130 
    131         Output
    132         ------
    133         h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
    134             Channel estimates for the pilot-carrying resource elements
    135 
    136         err_var : Same shape as ``h_hat``, tf.float
    137             Channel estimation error variance for the pilot-carrying
    138             resource elements
    139         """
    140         pass
    141 
    142     def call(self, inputs):
    143 
    144         y, no = inputs
    145 
    146         # y has shape:
    147         # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,..
    148         # ... fft_size]
    149         #
    150         # no can have shapes [], [batch_size], [batch_size, num_rx]
    151         # or [batch_size, num_rx, num_rx_ant]
    152 
    153         # Removed nulled subcarriers (guards, dc)
    154         y_eff = self._removed_nulled_scs(y)
    155 
    156         # Flatten the resource grid for pilot extraction
    157         # New shape: [...,num_ofdm_symbols*num_effective_subcarriers]
    158         y_eff_flat = flatten_last_dims(y_eff)
    159 
    160         # Gather pilots along the last dimensions
    161         # Resulting shape: y_eff_flat.shape[:-1] + pilot_ind.shape, i.e.:
    162         # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
    163         #  ..., num_pilot_symbols]
    164         y_pilots = tf.gather(y_eff_flat, self._pilot_ind, axis=-1)
    165 
    166         # Compute LS channel estimates
    167         # Note: Some might be Inf because pilots=0, but we do not care
    168         # as only the valid estimates will be considered during interpolation.
    169         # We do a save division to replace Inf by 0.
    170         # Broadcasting from pilots here is automatic since pilots have shape
    171         # [num_tx, num_streams, num_pilot_symbols]
    172         h_hat, err_var = self.estimate_at_pilot_locations(y_pilots, no)
    173 
    174         # Interpolate channel estimates over the resource grid
    175         if self._interpolation_type is not None:
    176             h_hat, err_var = self._interpol(h_hat, err_var)
    177             err_var = tf.maximum(err_var, tf.cast(0, err_var.dtype))
    178 
    179         return h_hat, err_var
    180 
    181 
    182 class LSChannelEstimator(BaseChannelEstimator, Layer):
    183     # pylint: disable=line-too-long
    184     r"""LSChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)
    185 
    186     Layer implementing least-squares (LS) channel estimation for OFDM MIMO systems.
    187 
    188     After LS channel estimation at the pilot positions, the channel estimates
    189     and error variances are interpolated accross the entire resource grid using
    190     a specified interpolation function.
    191 
    192     For simplicity, the underlying algorithm is described for a vectorized observation,
    193     where we have a nonzero pilot for all elements to be estimated.
    194     The actual implementation works on a full OFDM resource grid with sparse
    195     pilot patterns. The following model is assumed:
    196 
    197     .. math::
    198 
    199         \mathbf{y} = \mathbf{h}\odot\mathbf{p} + \mathbf{n}
    200 
    201     where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the received signal vector,
    202     :math:`\mathbf{p}\in\mathbb{C}^M` is the vector of pilot symbols,
    203     :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated,
    204     and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean noise vector whose
    205     elements have variance :math:`N_0`. The operator :math:`\odot` denotes
    206     element-wise multiplication.
    207 
    208     The channel estimate :math:`\hat{\mathbf{h}}` and error variances
    209     :math:`\sigma^2_i`, :math:`i=0,\dots,M-1`, are computed as
    210 
    211     .. math::
    212 
    213         \hat{\mathbf{h}} &= \mathbf{y} \odot
    214                            \frac{\mathbf{p}^\star}{\left|\mathbf{p}\right|^2}
    215                          = \mathbf{h} + \tilde{\mathbf{h}}\\
    216              \sigma^2_i &= \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right]
    217                          = \frac{N_0}{\left|p_i\right|^2}.
    218 
    219     The channel estimates and error variances are then interpolated accross
    220     the entire resource grid.
    221 
    222     Parameters
    223     ----------
    224     resource_grid : ResourceGrid
    225         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    226 
    227     interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
    228         The interpolation method to be used.
    229         It is ignored if ``interpolator`` is not `None`.
    230         Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn`")
    231         or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
    232         averaging across OFDM symbols (`"lin_time_avg"`).
    233         Defaults to "nn".
    234 
    235     interpolator : BaseChannelInterpolator
    236         An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
    237         such as :class:`~sionna.ofdm.LMMSEInterpolator`,
    238         or `None`. In the latter case, the interpolator specfied
    239         by ``interpolation_type`` is used.
    240         Otherwise, the ``interpolator`` is used and ``interpolation_type``
    241         is ignored.
    242         Defaults to `None`.
    243 
    244     dtype : tf.Dtype
    245         Datatype for internal calculations and the output dtype.
    246         Defaults to `tf.complex64`.
    247 
    248     Input
    249     -----
    250     (y, no) :
    251         Tuple:
    252 
    253     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
    254         Observed resource grid
    255 
    256     no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
    257         Variance of the AWGN
    258 
    259     Output
    260     ------
    261     h_ls : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols,fft_size], tf.complex
    262         Channel estimates accross the entire resource grid for all
    263         transmitters and streams
    264 
    265     err_var : Same shape as ``h_ls``, tf.float
    266         Channel estimation error variance accross the entire resource grid
    267         for all transmitters and streams
    268     """
    269 
    270     def estimate_at_pilot_locations(self, y_pilots, no):
    271 
    272         # y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams,
    273         #               num_pilot_symbols], tf.complex
    274         #     The observed signals for the pilot-carrying resource elements.
    275 
    276         # no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims,
    277         #   tf.float
    278         #     The variance of the AWGN.
    279 
    280         # Compute LS channel estimates
    281         # Note: Some might be Inf because pilots=0, but we do not care
    282         # as only the valid estimates will be considered during interpolation.
    283         # We do a save division to replace Inf by 0.
    284         # Broadcasting from pilots here is automatic since pilots have shape
    285         # [num_tx, num_streams, num_pilot_symbols]
    286         h_ls = tf.math.divide_no_nan(y_pilots, self._pilot_pattern.pilots)
    287 
    288         # Compute error variance and broadcast to the same shape as h_ls
    289         # Expand rank of no for broadcasting
    290         no = expand_to_rank(no, tf.rank(h_ls), -1)
    291 
    292         # Expand rank of pilots for broadcasting
    293         pilots = expand_to_rank(self._pilot_pattern.pilots, tf.rank(h_ls), 0)
    294 
    295         # Compute error variance, broadcastable to the shape of h_ls
    296         err_var = tf.math.divide_no_nan(no, tf.abs(pilots)**2)
    297 
    298         return h_ls, err_var
    299 
    300 
class BaseChannelInterpolator(ABC):
    # pylint: disable=line-too-long
    r"""BaseChannelInterpolator()

    Abstract layer for implementing an OFDM channel interpolator.

    Any layer that implements an OFDM channel interpolator must implement this
    callable class.

    A channel interpolator is used by an OFDM channel estimator
    (:class:`~sionna.ofdm.BaseChannelEstimator`) to compute channel estimates
    for the data-carrying resource elements from the channel estimates for the
    pilot-carrying resource elements.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """

    @abstractmethod
    def __call__(self, h_hat, err_var):
        # Concrete interpolators implement the interpolation of both the
        # channel estimates and their error variances here
        pass
    337 
    338 
    339 class NearestNeighborInterpolator(BaseChannelInterpolator):
    340     # pylint: disable=line-too-long
    341     r"""NearestNeighborInterpolator(pilot_pattern)
    342 
    343     Nearest-neighbor channel estimate interpolation on a resource grid.
    344 
    345     This class assigns to each element of an OFDM resource grid one of
    346     ``num_pilots`` provided channel estimates and error
    347     variances according to the nearest neighbor method. It is assumed
    348     that the measurements were taken at the nonzero positions of a
    349     :class:`~sionna.ofdm.PilotPattern`.
    350 
    351     The figure below shows how four channel estimates are interpolated
    352     accross a resource grid. Grey fields indicate measurement positions
    353     while the colored regions show which resource elements are assigned
    354     to the same measurement value.
    355 
    356     .. image:: ../figures/nearest_neighbor_interpolation.png
    357 
    358     Parameters
    359     ----------
    360     pilot_pattern : PilotPattern
    361         An instance of :class:`~sionna.ofdm.PilotPattern`
    362 
    363     Input
    364     -----
    365     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
    366         Channel estimates for the pilot-carrying resource elements
    367 
    368     err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
    369         Channel estimation error variances for the pilot-carrying resource elements
    370 
    371     Output
    372     ------
    373     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
    374         Channel estimates accross the entire resource grid for all
    375         transmitters and streams
    376 
    377     err_var : Same shape as ``h_hat``, tf.float
    378         Channel estimation error variances accross the entire resource grid
    379         for all transmitters and streams
    380     """
    381     def __init__(self, pilot_pattern):
    382         super().__init__()
    383 
    384         assert(pilot_pattern.num_pilot_symbols>0),\
    385             """The pilot pattern cannot be empty"""
    386 
    387         # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
    388         mask = np.array(pilot_pattern.mask)
    389         mask_shape = mask.shape # Store to reconstruct the original shape
    390         mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))
    391 
    392         # Reshape the pilots to shape [-1, num_pilot_symbols]
    393         pilots = pilot_pattern.pilots
    394         pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])
    395 
    396         max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
    397         assert max_num_zero_pilots<pilots.shape[-1],\
    398             """Each pilot sequence must have at least one nonzero entry"""
    399 
    400         # Compute gather indices for nearest neighbor interpolation
    401         gather_ind = np.zeros_like(mask, dtype=np.int32)
    402         for a in range(gather_ind.shape[0]): # For each pilot pattern...
    403             i_p, j_p = np.where(mask[a]) # ...determine the pilot indices
    404 
    405             for i in range(mask_shape[-2]): # Iterate over...
    406                 for j in range(mask_shape[-1]): # ... all resource elements
    407 
    408                     # Compute Manhattan distance to all pilot positions
    409                     d = np.abs(i-i_p) + np.abs(j-j_p)
    410 
    411                     # Set the distance at all pilot positions with zero energy
    412                     # equal to the maximum possible distance
    413                     d[np.abs(pilots[a])==0] = np.sum(mask_shape[-2:])
    414 
    415                     # Find the pilot index with the shortest distance...
    416                     ind = np.argmin(d)
    417 
    418                     # ... and store it in the index tensor
    419                     gather_ind[a, i, j] = ind
    420 
    421         # Reshape to the original shape of the mask, i.e.:
    422         # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
    423         #  ..., num_effective_subcarriers]
    424         self._gather_ind = tf.reshape(gather_ind, mask_shape)
    425 
    426     def _interpolate(self, inputs):
    427         # inputs has shape:
    428         # [k, l, m, num_tx, num_streams_per_tx, num_pilots]
    429 
    430         # Transpose inputs to bring batch_dims for gather last. New shape:
    431         # [num_tx, num_streams_per_tx, num_pilots, k, l, m]
    432         perm = tf.roll(tf.range(tf.rank(inputs)), -3, 0)
    433         inputs = tf.transpose(inputs, perm)
    434 
    435         # Interpolate through gather. Shape:
    436         # [num_tx, num_streams_per_tx, num_ofdm_symbols,
    437         #  ..., num_effective_subcarriers, k, l, m]
    438         outputs = tf.gather(inputs, self._gather_ind, 2, batch_dims=2)
    439 
    440         # Transpose outputs to bring batch_dims first again. New shape:
    441         # [k, l, m, num_tx, num_streams_per_tx,...
    442         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    443         perm = tf.roll(tf.range(tf.rank(outputs)), 3, 0)
    444         outputs = tf.transpose(outputs, perm)
    445 
    446         return outputs
    447 
    448     def __call__(self, h_hat, err_var):
    449 
    450         h_hat = self._interpolate(h_hat)
    451         err_var = self._interpolate(err_var)
    452         return h_hat, err_var
    453 
    454 
    455 class LinearInterpolator(BaseChannelInterpolator):
    456     # pylint: disable=line-too-long
    457     r"""LinearInterpolator(pilot_pattern, time_avg=False)
    458 
    459     Linear channel estimate interpolation on a resource grid.
    460 
    461     This class computes for each element of an OFDM resource grid
    462     a channel estimate based on ``num_pilots`` provided channel estimates and
    463     error variances through linear interpolation.
    464     It is assumed that the measurements were taken at the nonzero positions
    465     of a :class:`~sionna.ofdm.PilotPattern`.
    466 
    467     The interpolation is done first across sub-carriers and then
    468     across OFDM symbols.
    469 
    470     Parameters
    471     ----------
    472     pilot_pattern : PilotPattern
    473         An instance of :class:`~sionna.ofdm.PilotPattern`
    474 
    475     time_avg : bool
    476         If enabled, measurements will be averaged across OFDM symbols
    477         (i.e., time). This is useful for channels that do not vary
    478         substantially over the duration of an OFDM frame. Defaults to `False`.
    479 
    480     Input
    481     -----
    482     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
    483         Channel estimates for the pilot-carrying resource elements
    484 
    485     err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
    486         Channel estimation error variances for the pilot-carrying resource elements
    487 
    488     Output
    489     ------
    490     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
    491         Channel estimates accross the entire resource grid for all
    492         transmitters and streams
    493 
    494     err_var : Same shape as ``h_hat``, tf.float
    495         Channel estimation error variances accross the entire resource grid
    496         for all transmitters and streams
    497     """
    498     def __init__(self, pilot_pattern, time_avg=False):
    499         super().__init__()
    500 
    501         assert(pilot_pattern.num_pilot_symbols>0),\
    502             """The pilot pattern cannot be empty"""
    503 
    504         self._time_avg = time_avg
    505 
    506         # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
    507         mask = np.array(pilot_pattern.mask)
    508         mask_shape = mask.shape # Store to reconstruct the original shape
    509         mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))
    510 
    511         # Reshape the pilots to shape [-1, num_pilot_symbols]
    512         pilots = pilot_pattern.pilots
    513         pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])
    514 
    515         max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
    516         assert max_num_zero_pilots<pilots.shape[-1],\
    517             """Each pilot sequence must have at least one nonzero entry"""
    518 
    519         # Create actual pilot patterns for each stream over the resource grid
    520         z = np.zeros_like(mask, dtype=pilots.dtype)
    521         for a in range(z.shape[0]):
    522             z[a][np.where(mask[a])] = pilots[a]
    523 
    524         # Linear interpolation works as follows:
    525         # We compute for each resource element (RE)
    526         # x_0 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
    527         #       the first channel measurement was taken
    528         # x_1 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
    529         #       the second channel measurement was taken
    530         # y_0 : The first channel estimate
    531         # y_1 : The second channel estimate
    532         # x   : The x-value (i.e., sub-carrier index or OFDM symbol)
    533         #
    534         # The linearly interpolated value y is then given as:
    535         # y = (x-x_0) * (y_1-y_0) / (x_1-x_0) + y_0
    536         #
    537         # The following code pre-computes various quantities and indices
    538         # that are needed to compute x_0, x_1, y_0, y_1, x for frequency- and
    539         # time-domain interpolation.
    540 
    541         ##
    542         ## Frequency-domain interpolation
    543         ##
    544         self._x_freq = tf.cast(expand_to_rank(tf.range(0, mask.shape[-1]),
    545                                               7,
    546                                               axis=0),
    547                                pilots.dtype)
    548 
    549         # Permutation indices to shift batch_dims last during gather
    550         self._perm_fwd_freq = tf.roll(tf.range(6), -3, 0)
    551 
    552         x_0_freq = np.zeros_like(mask, np.int32)
    553         x_1_freq = np.zeros_like(mask, np.int32)
    554 
    555         # Set REs of OFDM symbols without any pilot equal to -1 (dummy value)
    556         x_0_freq[np.sum(np.abs(z), axis=-1)==0] = -1
    557         x_1_freq[np.sum(np.abs(z), axis=-1)==0] = -1
    558 
    559         y_0_freq_ind = np.copy(x_0_freq) # Indices used to gather estimates
    560         y_1_freq_ind = np.copy(x_1_freq) # Indices used to gather estimates
    561 
    562         # For each stream
    563         for a in range(z.shape[0]):
    564 
    565             pilot_count = 0 # Counts the number of non-zero pilots
    566 
    567             # Indices of non-zero pilots within the pilots vector
    568             pilot_ind = np.where(np.abs(pilots[a]))[0]
    569 
    570             # Go through all OFDM symbols
    571             for i in range(x_0_freq.shape[1]):
    572 
    573                 # Indices of non-zero pilots within the OFDM symbol
    574                 pilot_ind_ofdm = np.where(np.abs(z[a][i]))[0]
    575 
    576                 # If OFDM symbol contains only one non-zero pilot
    577                 if len(pilot_ind_ofdm)==1:
    578                     # Set the indices of the first and second pilot to the same
    579                     # value for all REs of the OFDM symbol
    580                     x_0_freq[a][i] = pilot_ind_ofdm[0]
    581                     x_1_freq[a][i] = pilot_ind_ofdm[0]
    582                     y_0_freq_ind[a,i] = pilot_ind[pilot_count]
    583                     y_1_freq_ind[a,i] = pilot_ind[pilot_count]
    584 
    585                 # If OFDM symbol contains two or more pilots
    586                 elif len(pilot_ind_ofdm)>=2:
    587                     x0 = 0
    588                     x1 = 1
    589 
    590                     # Go through all resource elements of this OFDM symbol
    591                     for j in range(x_0_freq.shape[2]):
    592                         x_0_freq[a,i,j] = pilot_ind_ofdm[x0]
    593                         x_1_freq[a,i,j] = pilot_ind_ofdm[x1]
    594                         y_0_freq_ind[a,i,j] = pilot_ind[pilot_count + x0]
    595                         y_1_freq_ind[a,i,j] = pilot_ind[pilot_count + x1]
    596                         if j==pilot_ind_ofdm[x1] and x1<len(pilot_ind_ofdm)-1:
    597                             x0 = x1
    598                             x1 += 1
    599 
    600                 pilot_count += len(pilot_ind_ofdm)
    601 
    602         x_0_freq = np.reshape(x_0_freq, mask_shape)
    603         x_1_freq = np.reshape(x_1_freq, mask_shape)
    604         x_0_freq = expand_to_rank(x_0_freq, 7, axis=0)
    605         x_1_freq = expand_to_rank(x_1_freq, 7, axis=0)
    606         self._x_0_freq = tf.cast(x_0_freq, pilots.dtype)
    607         self._x_1_freq = tf.cast(x_1_freq, pilots.dtype)
    608 
    609         # We add +1 here to shift all indices as the input will be padded
    610         # at the beginning with 0, (i.e., the dummy index -1 will become 0).
    611         self._y_0_freq_ind = np.reshape(y_0_freq_ind, mask_shape)+1
    612         self._y_1_freq_ind = np.reshape(y_1_freq_ind, mask_shape)+1
    613 
    614         ##
    615         ## Time-domain interpolation
    616         ##
    617         self._x_time = tf.expand_dims(tf.range(0, mask.shape[-2]), -1)
    618         self._x_time = tf.cast(expand_to_rank(self._x_time, 7, axis=0),
    619                                dtype=pilots.dtype)
    620 
    621         # Indices used to gather estimates
    622         self._perm_fwd_time = tf.roll(tf.range(7), -3, 0)
    623 
    624         y_0_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices
    625         y_1_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices
    626 
    627         # For each stream
    628         for a in range(z.shape[0]):
    629 
    630             # Indices of OFDM symbols for which channel estimates were computed
    631             ofdm_ind = np.where(np.sum(np.abs(z[a]), axis=-1))[0]
    632 
    633             # Only one OFDM symbol with pilots
    634             if len(ofdm_ind)==1:
    635                 y_0_time_ind[a] = ofdm_ind[0]
    636                 y_1_time_ind[a] = ofdm_ind[0]
    637 
    638             # Two or more OFDM symbols with pilots
    639             elif len(ofdm_ind)>=2:
    640                 x0 = 0
    641                 x1 = 1
    642                 for i in range(z.shape[1]):
    643                     y_0_time_ind[a,i] = ofdm_ind[x0]
    644                     y_1_time_ind[a,i] = ofdm_ind[x1]
    645                     if i==ofdm_ind[x1] and x1<len(ofdm_ind)-1:
    646                         x0 = x1
    647                         x1 += 1
    648 
    649         self._y_0_time_ind = np.reshape(y_0_time_ind, mask_shape[:-1])
    650         self._y_1_time_ind = np.reshape(y_1_time_ind, mask_shape[:-1])
    651 
    652         self._x_0_time = expand_to_rank(tf.expand_dims(self._y_0_time_ind, -1),
    653                                                        7, axis=0)
    654         self._x_0_time = tf.cast(self._x_0_time, dtype=pilots.dtype)
    655         self._x_1_time = expand_to_rank(tf.expand_dims(self._y_1_time_ind, -1),
    656                                                        7, axis=0)
    657         self._x_1_time = tf.cast(self._x_1_time, dtype=pilots.dtype)
    658 
    659         #
    660         # Other precomputed values
    661         #
    662         # Undo permutation of batch_dims for gather
    663         self._perm_bwd = tf.roll(tf.range(7), 3, 0)
    664 
    665         # Padding for the inputs
    666         pad = np.zeros([6, 2], np.int32)
    667         pad[-1, 0] = 1
    668         self._pad = pad
    669 
    670         # Number of ofdm symbols carrying at least one pilot.
    671         # Used for time-averaging (optional)
    672         n = np.sum(np.abs(np.reshape(z, mask_shape)), axis=-1, keepdims=True)
    673         n = np.sum(n>0, axis=-2, keepdims=True)
    674         self._num_pilot_ofdm_symbols = expand_to_rank(n, 7, axis=0)
    675 
    676 
    def _interpolate_1d(self, inputs, x, x0, x1, y0_ind, y1_ind):
        """Linear interpolation along one resource-grid axis.

        Gathers the two support points selected by ``y0_ind``/``y1_ind``,
        restores the original axis ordering, and evaluates the straight
        line through ``(x0, y0)`` and ``(x1, y1)`` at position ``x``.
        """
        # Pick the two support values for every target position
        lower = tf.gather(inputs, y0_ind, axis=2, batch_dims=2)
        upper = tf.gather(inputs, y1_ind, axis=2, batch_dims=2)

        # Bring the batch dimensions back to the front
        lower = tf.transpose(lower, self._perm_bwd)
        upper = tf.transpose(upper, self._perm_bwd)

        # y = y0 + (x - x0) * (y1 - y0) / (x1 - x0);
        # divide_no_nan yields 0 slope when x0 == x1 (single support point)
        spacing = tf.cast(x1 - x0, dtype=lower.dtype)
        gradient = tf.math.divide_no_nan(upper - lower, spacing)
        offset = tf.cast(x - x0, dtype=lower.dtype)
        return lower + offset * gradient
    689 
    def _interpolate(self, inputs):
        """Interpolate pilot-based estimates over the whole resource grid.

        ``inputs`` has shape
        [k, l, m, num_tx, num_streams_per_tx, num_pilots]. The estimates
        are first interpolated across subcarriers, optionally averaged
        over the pilot-carrying OFDM symbols, and finally interpolated
        across OFDM symbols.
        """
        # Prepend a zero so that the dummy gather index 0 maps to an
        # all-zero (undefined) channel estimate.
        padded = tf.pad(inputs, self._pad, constant_values=0)

        # Move the gather batch dims last. New shape:
        # [num_tx, num_streams_per_tx, 1+num_pilots, k, l, m]
        padded = tf.transpose(padded, self._perm_fwd_freq)

        # Frequency-domain interpolation. Resulting shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  num_effective_subcarriers]
        h_freq = self._interpolate_1d(padded,
                                      self._x_freq,
                                      self._x_0_freq,
                                      self._x_1_freq,
                                      self._y_0_freq_ind,
                                      self._y_1_freq_ind)

        # Optional averaging across all pilot-carrying OFDM symbols
        if self._time_avg:
            num_ofdm_symbols = h_freq.shape[-2]
            acc = tf.reduce_sum(h_freq, axis=-2, keepdims=True)
            acc = acc / tf.cast(self._num_pilot_ofdm_symbols, acc.dtype)
            h_freq = tf.repeat(acc, [num_ofdm_symbols], axis=-2)

        # Move the gather batch dims last again. New shape:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  num_effective_subcarriers, k, l, m]
        time_in = tf.transpose(h_freq, self._perm_fwd_time)

        # Time-domain interpolation. Resulting shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  num_effective_subcarriers]
        return self._interpolate_1d(time_in,
                                    self._x_time,
                                    self._x_0_time,
                                    self._x_1_time,
                                    self._y_0_time_ind,
                                    self._y_1_time_ind)
    744 
    def __call__(self, h_hat, err_var):
        """Interpolate channel estimates and error variances over the grid.

        Input
        -----
        h_hat : tf.complex
            Channel estimates at pilot positions,
            [k, l, m, num_tx, num_streams_per_tx, num_pilots]

        err_var : tf.float
            Channel estimation error variances at pilot positions,
            same shape as ``h_hat``

        Output
        ------
        h_hat, err_var : interpolated estimates and (real-valued) error
            variances for all resource elements
        """
        h_hat = self._interpolate(h_hat)

        # The interpolator requires complex-valued inputs.
        # Cast to the dtype of the interpolated estimates rather than a
        # hard-coded tf.complex64 so that double-precision
        # (tf.complex128) pipelines do not hit a dtype mismatch inside
        # _interpolate_1d.
        err_var = tf.cast(err_var, h_hat.dtype)
        err_var = self._interpolate(err_var)
        err_var = tf.math.real(err_var)

        return h_hat, err_var
    755 
    756 
class LMMSEInterpolator1D:
    # pylint: disable=line-too-long
    r"""LMMSEInterpolator1D(pilot_mask, cov_mat, last_step)

    This class performs LMMSE interpolation across the inner dimension of the input ``h_hat``.

    The two inner dimensions of the input ``h_hat`` form a matrix :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`.
    LMMSE interpolation is performed across the inner dimension as follows:

    .. math::
        \tilde{\mathbf{h}}_n = \mathbf{A}_n \hat{\mathbf{h}}_n

    where :math:`1 \leq n \leq N` and :math:`\hat{\mathbf{h}}_n` is
    the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`.
    :math:`\mathbf{A}_n` is the :math:`M \times M` interpolation LMMSE matrix:

    .. math::
        \mathbf{A}_n = \mathbf{R} \mathbf{\Pi}_n \left( \mathbf{\Pi}_n^\intercal \mathbf{R} \mathbf{\Pi}_n + \tilde{\mathbf{\Sigma}}_n \right)^{-1} \mathbf{\Pi}_n^\intercal.

    where :math:`\mathbf{R}` is the :math:`M \times M` covariance matrix across the inner dimension of the quantity which is estimated,
    :math:`\mathbf{\Pi}_n` the :math:`M \times K_n` matrix that spreads :math:`K_n`
    values to a vector of size :math:`M` according to the ``pilot_mask`` for the :math:`n^{\text{th}}` row,
    and :math:`\tilde{\mathbf{\Sigma}}_n \in \mathbb{R}^{K_n \times K_n}` is the channel estimation
    error covariance, assumed to be diagonal. Its :math:`i^{\text{th}}` diagonal
    element is the entry of ``err_var`` at the :math:`i^{\text{th}}` pilot
    position of the :math:`n^{\text{th}}` row.

    The returned channel estimates are

    .. math::
        \begin{bmatrix}
            {\tilde{\mathbf{h}}_1}^\intercal\\
            \vdots\\
            {\tilde{\mathbf{h}}_N}^\intercal
        \end{bmatrix}.

    The returned channel estimation error variances are the diagonal coefficients of

    .. math::
        \text{diag} \left( \mathbf{R} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R} \right), 1 \leq n \leq N

    where :math:`\mathbf{\Xi}_n` is the diagonal matrix of size :math:`M \times M` that zeros the
    columns corresponding to rows not carrying any pilots.
    Note that interpolation is not performed for rows not carrying any pilots.

    **Remark**: The interpolation matrix differs across rows as different
    rows may carry pilots on different elements and/or have different
    estimation error variances.

    Parameters
    ----------
    pilot_mask : [num_tx, num_streams_per_tx, :math:`N`, :math:`M`] : int
        Mask indicating the allocation of resource elements.
        0 : Data,
        1 : Pilot,
        2 : Not used

    cov_mat : [:math:`M`, :math:`M`], tf.complex
        Covariance matrix of the channel across the inner dimension.

    last_step : bool
        Set to `True` if this is the last interpolation step.
        Otherwise, set to `False`.
        If `False`, the output is scaled to ensure its variance is as expected
        by the following interpolation step.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates.

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.float
        Channel estimation error variances.

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates interpolated across the inner dimension.

    err_var : Same shape as ``h_hat``, tf.float
        The channel estimation error variances of the interpolated channel estimates.
    """

    def __init__(self, pilot_mask, cov_mat, last_step):
        # Precompute every index tensor and covariance sub-matrix needed by
        # __call__, so that the per-call work reduces to scatter/gather and
        # batched linear algebra.

        self._cdtype = cov_mat.dtype
        assert self._cdtype in (tf.complex64, tf.complex128),\
            "`cov_mat` dtype must be one of tf.complex64 or tf.complex128"
        self._rdtype = self._cdtype.real_dtype
        self._rzero = tf.constant(0.0, self._rdtype)

        # Interpolation is performed along the inner dimension of
        # the resource grid, which may be either the subcarriers
        # or the OFDM symbols dimension.
        # This dimension is referred to as the inner dimension.
        # The other dimension of the resource grid is referred to
        # as the outer dimension.

        # Size of the inner dimension.
        inner_dim_size = tf.shape(pilot_mask)[-1]
        self._inner_dim_size = inner_dim_size

        # Size of the outer dimension.
        outer_dim_size = tf.shape(pilot_mask)[-2]
        self._outer_dim_size = outer_dim_size

        self._cov_mat = cov_mat
        self._last_step = last_step

        # Computation of the interpolation matrix is done solving the
        # least-square problem:
        #
        # X = min_Z |AZ - B|_F^2
        #
        # where A = (\Pi_T R \Pi + S) and
        # B = R \Pi
        # where R is the channel covariance matrix, S the error
        # diagonal covariance matrix, and \Pi the matrix that spreads the pilots
        # according to the pilot pattern along the inner axis.

        # Extracting the locations of pilots from the pilot mask
        num_tx = tf.shape(pilot_mask)[0]
        num_streams_per_tx = tf.shape(pilot_mask)[1]

        # List of indices of pilots in the inner dimension for every
        # transmit antenna, stream, and outer dimension element.
        pilot_indices = []
        # Maximum number of pilots carried by an inner dimension.
        max_num_pil = 0
        # Indices used to add the error variance to the diagonal
        # elements of the covariance matrix restricted
        # to the elements carrying pilots.
        # These matrices are computed below.
        add_err_var_indices = np.zeros([num_tx, num_streams_per_tx,
                                        outer_dim_size, inner_dim_size, 5], int)
        for tx in range(num_tx):
            pilot_indices.append([])
            for st in range(num_streams_per_tx):
                pilot_indices[-1].append([])
                for oi in range(outer_dim_size):
                    pilot_indices[-1][-1].append([])
                    num_pil = 0 # Number of pilots on this outer dim
                    for ii in range(inner_dim_size):
                        # Check if this RE is carrying a pilot
                        # for this stream
                        if pilot_mask[tx,st,oi,ii] == 0:
                            continue
                        if pilot_mask[tx,st,oi,ii] == 1:
                            pilot_indices[tx][st][oi].append(ii)
                            indices = [tx, st, oi, num_pil, num_pil]
                            add_err_var_indices[tx, st, oi, ii] = indices
                            num_pil += 1
                    max_num_pil = max(max_num_pil, num_pil)
        # [num_tx, num_streams_per_tx, outer_dim_size, inner_dim_size, 5]
        self._add_err_var_indices = tf.cast(add_err_var_indices, tf.int32)

        # Different subcarriers/symbols may carry a different number of pilots.
        # To handle such cases, we create a tensor of square matrices of
        # size the maximum number of pilots carried by an inner dimension
        # and zero-padding is used to handle axes with less pilots than the
        # maximum value. The obtained structure is:
        #
        # |B 0|
        # |0 0|
        #
        pil_cov_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                                max_num_pil, max_num_pil], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            num_pil = len(pil_ind)
            tmp = np.take(cov_mat, pil_ind, axis=0)
            pil_cov_mat_ = np.take(tmp, pil_ind, axis=1)
            pil_cov_mat[tx,st,oi,:num_pil,:num_pil] = pil_cov_mat_
        # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil]
        self._pil_cov_mat = tf.constant(pil_cov_mat, self._cdtype)

        # Pre-compute the covariance matrix with only the columns corresponding
        # to pilots.
        b_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                                max_num_pil, inner_dim_size], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            num_pil = len(pil_ind)
            b_mat_ = np.take(cov_mat, pil_ind, axis=0)
            b_mat[tx,st,oi,:num_pil,:] = b_mat_
        self._b_mat = tf.constant(b_mat, self._cdtype)

        # Indices used to fill with zeros the columns of the interpolation
        # matrix not corresponding to pilots.
        # The result is a matrix of size inner_dim_size x inner_dim_size
        # where rows and columns not corresponding to pilots are set to zero.
        pil_loc = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                            inner_dim_size, max_num_pil, 5], dtype=int)
        for tx,st,oi,p,ii in itertools.product(range(num_tx),
                                                range(num_streams_per_tx),
                                                range(outer_dim_size),
                                                range(max_num_pil),
                                                range(inner_dim_size)):
            if p >= len(pilot_indices[tx][st][oi]):
                # An extra dummy inner-dimension element is added; the
                # padding coefficients are scattered there and dropped later.
                pil_loc[tx, st, oi, ii, p] = [tx, st, oi,
                                              inner_dim_size,
                                              inner_dim_size]
            else:
                pil_loc[tx, st, oi, ii, p] = [tx, st, oi,
                                              ii,
                                              pilot_indices[tx][st][oi][p]]
        self._pil_loc = tf.cast(pil_loc, tf.int32)

        # Covariance matrix for each stream with only the row corresponding
        # to a pilot carrying RE not set to 0.
        # This is required to compute the estimation error variances.
        err_var_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                inner_dim_size, inner_dim_size], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            mask = np.zeros([inner_dim_size], complex)
            mask[pil_ind] = 1.0
            mask = np.expand_dims(mask, axis=1)
            err_var_mat[tx,st,oi] = cov_mat*mask
        self._err_var_mat = tf.constant(err_var_mat, self._cdtype)

    def __call__(self, h_hat, err_var):
        # Perform LMMSE interpolation along the inner dimension and
        # compute the resulting estimation error variances.
        #
        # h_hat : [batch_size, num_rx, num_rx_ant, num_tx,
        #          num_streams_per_tx, outer_dim_size, inner_dim_size]
        # err_var : [batch_size, num_rx, num_rx_ant, num_tx,
        #          num_streams_per_tx, outer_dim_size, inner_dim_size]

        batch_size = tf.shape(h_hat)[0]
        num_rx = tf.shape(h_hat)[1]
        num_rx_ant = tf.shape(h_hat)[2]
        num_tx = tf.shape(h_hat)[3]
        num_tx_stream = tf.shape(h_hat)[4]
        outer_dim_size = self._outer_dim_size
        inner_dim_size = self._inner_dim_size

        #####################################
        # Compute the interpolation matrix
        #####################################

        # Computation of the interpolation matrix is done solving the
        # least-square problem:
        #
        # X = min_Z |AZ - B|_F^2
        #
        # where A = (\Pi_T R \Pi + S) and
        # B = R \Pi
        # where R is the channel covariance matrix, S the error
        # diagonal covariance matrix, and \Pi the matrix that spreads the pilots
        # according to the pilot pattern along the inner axis.

        #
        # Computing A
        #

        # Covariance matrices restricted to pilot locations
        # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat = self._pil_cov_mat

        # Adding batch, receive, and receive antennas dimensions to the
        # covariance matrices restricted to pilot locations
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat = expand_to_rank(pil_cov_mat, 8, 0)
        pil_cov_mat = tf.tile(pil_cov_mat, [batch_size, num_rx, num_rx_ant,
                                                     1, 1, 1, 1, 1])

        # Adding the estimation error variances to the diagonal of the
        # covariance matrices restricted to pilots
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat_ = tf.transpose(pil_cov_mat, [3, 4, 5, 6, 7, 0, 1, 2])
        err_var_ = tf.complex(err_var, self._rzero)
        err_var_ = tf.transpose(err_var_, [3, 4, 5, 6, 0, 1, 2])
        a_mat = tf.tensor_scatter_nd_add(pil_cov_mat_,
                                        self._add_err_var_indices, err_var_)
        a_mat = tf.transpose(a_mat, [5, 6, 7, 0, 1, 2, 3, 4])

        #
        # Computing B
        #

        # B is pre-computed as it only depends on the channel covariance and
        # pilot pattern.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, inner_dim_size]
        b_mat = self._b_mat
        b_mat = expand_to_rank(b_mat, 8, 0)
        b_mat = tf.tile(b_mat, [batch_size, num_rx, num_rx_ant,
                                1, 1, 1, 1, 1])

        #
        # Computing the interpolation matrix
        #

        # Using lstsq to compute the columns of the interpolation matrix
        # corresponding to pilots.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size, max_num_pil]
        ext_mat = tf.linalg.lstsq(a_mat, b_mat, fast=False)
        ext_mat = tf.transpose(ext_mat, [0,1,2,3,4,5,7,6], conjugate=True)

        # Filling with zeros the columns not corresponding to pilots.
        # An extra dummy outer dim is added to scatter there the coefficients
        # of the identity matrix used for padding.
        # This dummy dim is then removed.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size, inner_dim_size]
        ext_mat = tf.transpose(ext_mat, [3, 4, 5, 6, 7, 0, 1, 2])
        ext_mat = tf.scatter_nd(self._pil_loc, ext_mat,
                                            [num_tx, num_tx_stream,
                                             outer_dim_size,
                                             inner_dim_size+1,
                                             inner_dim_size+1,
                                             batch_size, num_rx, num_rx_ant])
        ext_mat = tf.transpose(ext_mat, [5, 6, 7, 0, 1, 2, 3, 4])
        ext_mat = ext_mat[...,:inner_dim_size,:inner_dim_size]

        ################################################
        # Apply interpolation over the inner dimension
        ################################################

        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        h_hat = tf.expand_dims(h_hat, axis=-1)
        h_hat = tf.matmul(ext_mat, h_hat)
        h_hat = tf.squeeze(h_hat, axis=-1)

        ##############################
        # Compute the error variances
        ##############################

        # Keep track of the previous estimation error variances for later use
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        err_var_old = err_var

        # err_var = diag(R - A \Xi R), computed as a row-wise inner product
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        cov_mat = expand_to_rank(self._cov_mat, 8, 0)
        err_var = tf.linalg.diag_part(cov_mat)
        err_var_mat = expand_to_rank(self._err_var_mat, 8, 0)
        err_var_mat = tf.transpose(err_var_mat, [0, 1, 2, 3, 4, 5, 7, 6])
        err_var = err_var - tf.reduce_sum(ext_mat*err_var_mat, axis=-1)
        err_var = tf.math.real(err_var)
        err_var = tf.maximum(err_var, self._rzero)

        #####################################
        # If this is *not* the last
        # interpolation step, scales the
        # input `h_hat` to ensure
        # it has the variance expected by the
        # next interpolation step.
        #
        # The error variance also `err_var`
        # is updated accordingly.
        #####################################
        if not self._last_step:
            #
            # Variance of h_hat
            #
            # Conjugate transpose of LMMSE matrix
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size, inner_dim_size]
            ext_mat_h = tf.transpose(ext_mat, [0, 1, 2, 3, 4, 5, 7, 6],
                                     conjugate=True)
            # First part of the estimate covariance
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size, inner_dim_size]
            h_hat_var_1 = tf.matmul(cov_mat, ext_mat_h)
            h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6])
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat_var_1 = tf.reduce_sum(ext_mat*h_hat_var_1, axis=-1)
            # Second part of the estimate covariance
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var_old_c = tf.complex(err_var_old, self._rzero)
            err_var_old_c = tf.expand_dims(err_var_old_c, axis=-1)
            h_hat_var_2 = err_var_old_c*ext_mat_h
            h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6])
            h_hat_var_2 = tf.reduce_sum(ext_mat*h_hat_var_2, axis=-1)
            # Variance of h_hat
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat_var = h_hat_var_1 + h_hat_var_2
            # Scaling factor
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var_c = tf.complex(err_var, self._rzero)
            h_var = tf.linalg.diag_part(cov_mat)
            s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c)
            # Apply scaling to estimate
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat = s*h_hat
            # Updated variance
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c
            err_var = tf.math.real(err_var)
            err_var = tf.maximum(err_var, self._rzero)

        return h_hat, err_var
   1174 
   1175 class SpatialChannelFilter:
   1176     # pylint: disable=line-too-long
   1177     r"""SpatialChannelFilter(cov_mat, last_step)
   1178 
   1179     Implements linear minimum mean square error (LMMSE) smoothing.
   1180 
   1181     We consider the following model:
   1182 
   1183     .. math::
   1184 
   1185         \mathbf{y} = \mathbf{h} + \mathbf{n}
   1186 
   1187     where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the received signal vector,
   1188     :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated
   1189     with covariance matrix
   1190     :math:`\mathbb{E}\left[ \mathbf{h} \mathbf{h}^{\mathsf{H}} \right] = \mathbf{R}`,
   1191     and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean noise vector whose
   1192     elements have variance :math:`N_0`.
   1193 
   1194     The channel estimate :math:`\hat{\mathbf{h}}` is computed as
   1195 
   1196     .. math::
   1197 
   1198         \hat{\mathbf{h}} &= \mathbf{A} \mathbf{y}
   1199 
   1200     where
   1201 
   1202     .. math::
   1203 
   1204         \mathbf{A} = \mathbf{R} \left( \mathbf{R} + N_0 \mathbf{I}_M \right)^{-1}
   1205 
   1206     where :math:`\mathbf{I}_M` is the :math:`M \times M` identity matrix.
   1207     The estimation error is:
   1208 
   1209     .. math::
   1210 
   1211         \tilde{h} = \mathbf{h} - \hat{\mathbf{h}}
   1212 
   1213     The error variances
   1214 
   1215     .. math::
   1216 
   1217              \sigma^2_i = \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right], 0 \leq i \leq M-1
   1218 
   1219     are the diagonal elements of
   1220 
   1221     .. math::
   1222 
   1223         \mathbb{E}\left[\mathbf{\tilde{h}} \mathbf{\tilde{h}}^{\mathsf{H}} \right] = \mathbf{R} - \mathbf{A}\mathbf{R}.
   1224 
   1225 
   1226     Note
   1227     ----
   1228     If you want to use this function in Graph mode with XLA, i.e., within
   1229     a function that is decorated with ``@tf.function(jit_compile=True)``,
   1230     you must set ``sionna.Config.xla_compat=true``.
   1231     See :py:attr:`~sionna.Config.xla_compat`.
   1232 
   1233     Parameters
   1234     ----------
   1235     cov_mat : [num_rx_ant, num_rx_ant], tf.complex
   1236         Spatial covariance matrix of the channel
   1237 
   1238     last_step : bool
   1239         Set to `True` if this is the last interpolation step.
   1240         Otherwise, set to `False`.
   1241         If `True`, the the output is scaled to ensure its variance is as expected
   1242         by the following interpolation step.
   1243 
   1244     Input
   1245     -----
   1246     h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
   1247         Channel estimates.
   1248 
   1249     err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
   1250         Channel estimation error variances.
   1251 
   1252     Output
   1253     ------
   1254     h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
   1255         Channel estimates smoothed accross the spatial dimension
   1256 
   1257     err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
   1258         The channel estimation error variances of the smoothed channel estimates.
   1259     """
   1260 
   1261     def __init__(self, cov_mat, last_step):
   1262         self._rzero = tf.zeros((), cov_mat.dtype.real_dtype)
   1263         self._cov_mat = cov_mat
   1264         self._last_step = last_step
   1265 
   1266         # Indices for adding a tensor of vectors [..., num_rx_ant] to the
   1267         # diagonal of a tensor of matrices [..., num_rx_ant, num_rx_ant]
   1268         num_rx_ant = cov_mat.shape[0]
   1269         add_diag_indices = [[rxa, rxa] for rxa in range(num_rx_ant)]
   1270         self._add_diag_indices = tf.cast(add_diag_indices, tf.int32)
   1271 
    def __call__(self, h_hat, err_var):
        """
        Apply LMMSE smoothing across the receive antennas.

        For every resource element, the filter C = R (R + diag(err_var))^{-1}
        is computed from the spatial covariance matrix R and applied to the
        channel estimates. If this is not the last interpolation step, the
        output is additionally rescaled so that its variance matches the one
        expected by the following step, and the error variances are updated
        accordingly.
        """
        # h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx,
        #           num_ofdm_symbols, num_subcarriers, num_rx_ant]
        # err_var : [batch_size, num_rx, num_tx, num_streams_per_tx,
        #           num_ofdm_symbols, num_subcarriers, num_rx_ant]

        # [..., num_rx_ant]
        # Cast to complex so the variances can be added to the diagonal of
        # the complex covariance matrix below
        err_var = tf.complex(err_var, self._rzero)
        # Keep track of the previous estimation error variances for later use
        err_var_old = err_var

        # [num_rx_ant, num_rx_ant]
        cov_mat = self._cov_mat
        cov_mat_t = tf.transpose(cov_mat)
        num_rx_ant = tf.shape(cov_mat)[0]

        ##########################################
        # Compute LMMSE matrix
        # C = R (R + diag(err_var))^{-1}
        ##########################################

        # [..., num_rx_ant, num_rx_ant]
        cov_mat = expand_to_rank(cov_mat, tf.rank(err_var)+1, axis=0)

        # Adding the error variances to the diagonal
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = tf.broadcast_to(cov_mat, tf.concat([tf.shape(err_var),
                                                        [num_rx_ant]], axis=0))
        # Transpose so that the antenna axes lead, as tensor_scatter_nd_add
        # updates along the leading dimensions
        # [num_rx_ant, ...]
        err_var_ = tf.transpose(err_var, [6, 0, 1, 2, 3, 4, 5])
        # [num_rx_ant, num_rx_ant, ...]
        lmmse_mat = tf.transpose(lmmse_mat, [6, 7, 0, 1, 2, 3, 4, 5])
        lmmse_mat = tf.tensor_scatter_nd_add(lmmse_mat,
                                            self._add_diag_indices, err_var_)
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = tf.transpose(lmmse_mat, [2, 3, 4, 5, 6, 7, 0, 1])

        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = matrix_inv(lmmse_mat)
        lmmse_mat = tf.matmul(cov_mat, lmmse_mat)

        ##########################################
        # Apply smoothing
        ##########################################

        # [..., num_rx_ant, 1]
        h_hat = tf.expand_dims(h_hat, axis=-1)
        # [..., num_rx_ant]
        h_hat = tf.squeeze(tf.matmul(lmmse_mat, h_hat), axis=-1)

        ##########################################
        # Compute the estimation error variances
        # diag(R - C R)
        ##########################################

        # [..., num_rx_ant, num_rx_ant]
        cov_mat_t = expand_to_rank(cov_mat_t, tf.rank(lmmse_mat), axis=0)
        # Row-wise dot products implement diag(C R) without forming C R
        # [..., num_rx_ant]
        err_var = tf.reduce_sum(cov_mat_t*lmmse_mat, axis=-1)
        # [..., num_rx_ant]
        err_var = tf.linalg.diag_part(cov_mat) - err_var
        err_var = tf.math.real(err_var)
        # Numerical safety: variances must be non-negative
        err_var = tf.maximum(err_var, self._rzero)

        ##########################################
        # If this is *not* the last
        # interpolation step, scales the
        # input `h_hat` to ensure
        # it has the variance expected by the
        # next interpolation step.
        #
        # The error variance also `err_var`
        # is updated accordingly.
        ##########################################
        if not self._last_step:
            #
            # Variance of h_hat
            #
            # Conjugate transpose of the LMMSE matrix
            # [..., num_rx_ant, num_rx_ant]
            lmmse_mat_h = tf.transpose(lmmse_mat, [0, 1, 2, 3, 4, 5, 7, 6],
                                        conjugate=True)
            # First part of the estimate covariance
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_1 = tf.matmul(cov_mat, lmmse_mat_h)
            h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6])
            # [..., num_rx_ant]
            h_hat_var_1 = tf.reduce_sum(lmmse_mat*h_hat_var_1, axis=-1)
            # Second part of the estimate covariance
            # [..., num_rx_ant, 1]
            err_var_old = tf.expand_dims(err_var_old, axis=-1)
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_2 = err_var_old*lmmse_mat_h
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6])
            # [..., num_rx_ant]
            h_hat_var_2 = tf.reduce_sum(lmmse_mat*h_hat_var_2, axis=-1)
            # Variance of h_hat
            # [..., num_rx_ant]
            h_hat_var = h_hat_var_1 + h_hat_var_2
            # Scaling factor
            # [..., num_rx_ant]
            err_var_c = tf.complex(err_var, self._rzero)
            h_var = tf.linalg.diag_part(cov_mat)
            # divide_no_nan guards resource elements with no estimate at all
            s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c)
            # Apply scaling to estimate
            # [..., num_rx_ant]
            h_hat = s*h_hat
            # Updated variance
            # [..., num_rx_ant]
            err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c
            err_var = tf.math.real(err_var)
            err_var = tf.maximum(err_var, self._rzero)

        return h_hat, err_var
   1385 
   1386 
   1387 class LMMSEInterpolator(BaseChannelInterpolator):
   1388     # pylint: disable=line-too-long
   1389     r"""LMMSEInterpolator(pilot_pattern, cov_mat_time, cov_mat_freq, cov_mat_space=None, order='t-f')
   1390 
   1391     LMMSE interpolation on a resource grid with optional spatial smoothing.
   1392 
   1393     This class computes for each element of an OFDM resource grid
   1394     a channel estimate and error variance
   1395     through linear minimum mean square error (LMMSE) interpolation/smoothing.
   1396     It is assumed that the measurements were taken at the nonzero positions
   1397     of a :class:`~sionna.ofdm.PilotPattern`.
   1398 
   1399     Depending on the value of ``order``, the interpolation is carried out
   1400     accross time (t), i.e., OFDM symbols, frequency (f), i.e., subcarriers,
   1401     and optionally space (s), i.e., receive antennas, in any desired order.
   1402 
   1403     For simplicity, we describe the underlying algorithm assuming that interpolation
   1404     across the sub-carriers is performed first, followed by interpolation across
   1405     OFDM symbols, and finally by spatial smoothing across receive
   1406     antennas.
   1407     The algorithm is similar if interpolation and/or smoothing are performed in
   1408     a different order.
   1409     For clarity, antenna indices are omitted when describing frequency and time
   1410     interpolation, as the same process is applied to all the antennas.
   1411 
   1412     The input ``h_hat`` is first reshaped to a resource grid
   1413     :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`, by scattering the channel
   1414     estimates at pilot locations according to the ``pilot_pattern``. :math:`N`
   1415     denotes the number of OFDM symbols and :math:`M` the number of sub-carriers.
   1416 
   1417     The first pass consists in interpolating across the sub-carriers:
   1418 
   1419     .. math::
   1420         \hat{\mathbf{h}}_n^{(1)} = \mathbf{A}_n \hat{\mathbf{h}}_n
   1421 
   1422     where :math:`1 \leq n \leq N` is the OFDM symbol index and :math:`\hat{\mathbf{h}}_n` is
   1423     the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`.
   1424     :math:`\mathbf{A}_n` is the :math:`M \times M` matrix such that:
   1425 
   1426     .. math::
   1427         \mathbf{A}_n = \bar{\mathbf{A}}_n \mathbf{\Pi}_n^\intercal
   1428 
   1429     where
   1430 
   1431     .. math::
   1432         \bar{\mathbf{A}}_n = \underset{\mathbf{Z} \in \mathbb{C}^{M \times K_n}}{\text{argmin}} \left\lVert \mathbf{Z}\left( \mathbf{\Pi}_n^\intercal \mathbf{R^{(f)}} \mathbf{\Pi}_n + \mathbf{\Sigma}_n \right) - \mathbf{R^{(f)}} \mathbf{\Pi}_n \right\rVert_{\text{F}}^2
   1433 
   1434     and :math:`\mathbf{R^{(f)}}` is the :math:`M \times M` channel frequency covariance matrix,
   1435     :math:`\mathbf{\Pi}_n` the :math:`M \times K_n` matrix that spreads :math:`K_n`
   1436     values to a vector of size :math:`M` according to the ``pilot_pattern`` for the :math:`n^{\text{th}}` OFDM symbol,
   1437     and :math:`\mathbf{\Sigma}_n \in \mathbb{R}^{K_n \times K_n}` is the channel estimation error covariance built from
   1438     ``err_var`` and assumed to be diagonal.
   1439     Computation of :math:`\bar{\mathbf{A}}_n` is done using an algorithm based on complete orthogonal decomposition.
   1440     This is done to avoid matrix inversion for badly conditioned covariance matrices.
   1441 
   1442     The channel estimation error variances after the first interpolation pass are computed as
   1443 
   1444     .. math::
   1445         \mathbf{\Sigma}^{(1)}_n = \text{diag} \left( \mathbf{R^{(f)}} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R^{(f)}} \right)
   1446 
   1447     where :math:`\mathbf{\Xi}_n` is the diagonal matrix of size :math:`M \times M` that zeros the
   1448     columns corresponding to sub-carriers not carrying any pilots.
   1449     Note that interpolation is not performed for OFDM symbols which do not carry pilots.
   1450 
   1451     **Remark**: The interpolation matrix differs across OFDM symbols as different
   1452     OFDM symbols may carry pilots on different sub-carriers and/or have different
   1453     estimation error variances.
   1454 
   1455     Scaling of the estimates is then performed to ensure that their
   1456     variances match the ones expected by the next interpolation step, and the error variances are updated accordingly:
   1457 
   1458     .. math::
   1459         \begin{align}
   1460             \left[\hat{\mathbf{h}}_n^{(2)}\right]_m &= s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m\\
   1461             \left[\mathbf{\Sigma}^{(2)}_n\right]_{m,m}  &= s_{n,m}\left( s_{n,m}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m} + \left( 1 - s_{n,m} \right) \left[\mathbf{R^{(f)}}\right]_{m,m} + s_{n,m} \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m}
   1462         \end{align}
   1463 
   1464     where the scaling factor :math:`s_{n,m}` is such that:
   1465 
   1466 
   1467     .. math::
   1468         \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m \right\rvert^2 \right\} = \left[\mathbf{R^{(f)}}\right]_{m,m} +  \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}^{(1)}_n\right]_m - \left[\mathbf{h}_n\right]_m \right\rvert^2 \right\}
   1469 
   1470     which leads to:
   1471 
   1472     .. math::
   1473         \begin{align}
   1474             s_{n,m} &= \frac{2 \left[\mathbf{R^{(f)}}\right]_{m,m}}{\left[\mathbf{R^{(f)}}\right]_{m,m} - \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m} + \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m}}\\
   1475             \hat{\mathbf{\Sigma}}^{(1)}_n &= \mathbf{A}_n \mathbf{R^{(f)}} \mathbf{A}_n^{\mathrm{H}}.
   1476         \end{align}
   1477 
   1478     The second pass consists in interpolating across the OFDM symbols:
   1479 
   1480     .. math::
   1481         \hat{\mathbf{h}}_m^{(3)} = \mathbf{B}_m \tilde{\mathbf{h}}^{(2)}_m
   1482 
   1483     where :math:`1 \leq m \leq M` is the sub-carrier index and :math:`\tilde{\mathbf{h}}^{(2)}_m` is
   1484     the :math:`m^{\text{th}}` column of
   1485 
   1486     .. math::
   1487         \hat{\mathbf{H}}^{(2)} = \begin{bmatrix}
   1488                                     {\hat{\mathbf{h}}_1^{(2)}}^\intercal\\
   1489                                     \vdots\\
   1490                                     {\hat{\mathbf{h}}_N^{(2)}}^\intercal
   1491                                  \end{bmatrix}
   1492 
   1493     and :math:`\mathbf{B}_m` is the :math:`N \times N` interpolation LMMSE matrix:
   1494 
   1495     .. math::
   1496         \mathbf{B}_m = \bar{\mathbf{B}}_m \tilde{\mathbf{\Pi}}_m^\intercal
   1497 
   1498     where
   1499 
   1500     .. math::
   1501         \bar{\mathbf{B}}_m = \underset{\mathbf{Z} \in \mathbb{C}^{N \times L_m}}{\text{argmin}} \left\lVert \mathbf{Z} \left( \tilde{\mathbf{\Pi}}_m^\intercal \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m + \tilde{\mathbf{\Sigma}}^{(2)}_m \right) -  \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m \right\rVert_{\text{F}}^2
   1502 
   1503     where :math:`\mathbf{R^{(t)}}` is the :math:`N \times N` channel time covariance matrix,
   1504     :math:`\tilde{\mathbf{\Pi}}_m` the :math:`N \times L_m` matrix that spreads :math:`L_m`
   1505     values to a vector of size :math:`N` according to the ``pilot_pattern`` for the :math:`m^{\text{th}}` sub-carrier,
   1506     and :math:`\tilde{\mathbf{\Sigma}}^{(2)}_m \in \mathbb{R}^{L_m \times L_m}` is the diagonal matrix of channel estimation error variances
   1507     built by gathering the error variances from (:math:`\mathbf{\Sigma}^{(2)}_1,\dots,\mathbf{\Sigma}^{(2)}_N`) corresponding
   1508     to resource elements carried by the :math:`m^{\text{th}}` sub-carrier.
   1509     Computation of :math:`\bar{\mathbf{B}}_m` is done using an algorithm based on complete orthogonal decomposition.
   1510     This is done to avoid matrix inversion for badly conditioned covariance matrices.
   1511 
   1512     The resulting channel estimate for the resource grid is
   1513 
   1514     .. math::
   1515         \hat{\mathbf{H}}^{(3)} = \left[ \hat{\mathbf{h}}_1^{(3)} \dots \hat{\mathbf{h}}_M^{(3)} \right]
   1516 
   1517     The resulting channel estimation error variances are the diagonal coefficients of the matrices
   1518 
   1519     .. math::
   1520         \mathbf{\Sigma}^{(3)}_m = \mathbf{R^{(t)}} - \mathbf{B}_m \tilde{\mathbf{\Xi}}_m \mathbf{R^{(t)}}, 1 \leq m \leq M
   1521 
   1522     where :math:`\tilde{\mathbf{\Xi}}_m` is the diagonal matrix of size :math:`N \times N` that zeros the
   1523     columns corresponding to OFDM symbols not carrying any pilots.
   1524 
   1525     **Remark**: The interpolation matrix differs across sub-carriers as different
   1526     sub-carriers may have different estimation error variances computed by the first
   1527     pass.
   1528     However, all sub-carriers carry at least one channel estimate as a result of
   1529     the first pass, ensuring that a channel estimate is computed for all the resource
   1530     elements after the second pass.
   1531 
   1532     **Remark:** LMMSE interpolation requires knowledge of the time and frequency
   1533     covariance matrices of the channel. The notebook `OFDM MIMO Channel Estimation and Detection <../examples/OFDM_MIMO_Detection.ipynb>`_ shows how to estimate
   1534     such matrices for arbitrary channel models.
   1535     Moreover, the functions :func:`~sionna.ofdm.tdl_time_cov_mat`
   1536     and :func:`~sionna.ofdm.tdl_freq_cov_mat` compute the expected time and frequency
   1537     covariance matrices, respectively, for the :class:`~sionna.channel.tr38901.TDL` channel models.
   1538 
   1539     Scaling of the estimates is then performed to ensure that their
   1540     variances match the ones expected by the next smoothing step, and the
   1541     error variances are updated accordingly:
   1542 
   1543     .. math::
   1544         \begin{align}
   1545             \left[\hat{\mathbf{h}}_m^{(4)}\right]_n &= \gamma_{m,n} \left[\hat{\mathbf{h}}_m^{(3)}\right]_n\\
   1546             \left[\mathbf{\Sigma}^{(4)}_m\right]_{n,n}  &= \gamma_{m,n}\left( \gamma_{m,n}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(3)}_m\right]_{n,n} + \left( 1 - \gamma_{m,n} \right) \left[\mathbf{R^{(t)}}\right]_{n,n} + \gamma_{m,n} \left[\mathbf{\Sigma}^{(3)}_n\right]_{m,m}
   1547         \end{align}
   1548 
   1549     where:
   1550 
   1551     .. math::
   1552         \begin{align}
   1553             \gamma_{m,n} &= \frac{2 \left[\mathbf{R^{(t)}}\right]_{n,n}}{\left[\mathbf{R^{(t)}}\right]_{n,n} - \left[\mathbf{\Sigma}^{(3)}_m\right]_{n,n} + \left[\hat{\mathbf{\Sigma}}^{(3)}_n\right]_{m,m}}\\
   1554             \hat{\mathbf{\Sigma}}^{(3)}_m &= \mathbf{B}_m \mathbf{R^{(t)}} \mathbf{B}_m^{\mathrm{H}}
   1555         \end{align}
   1556 
   1557     Finally, a spatial smoothing step is applied to every resource element carrying
   1558     a channel estimate.
   1559     For clarity, we drop the resource element indexing :math:`(n,m)`.
   1560     We denote by :math:`L` the number of receive antennas, and by
   1561     :math:`\mathbf{R^{(s)}}\in\mathbb{C}^{L \times L}` the spatial covariance matrix.
   1562 
   1563     LMMSE spatial smoothing consists in the following computations:
   1564 
   1565     .. math::
   1566         \hat{\mathbf{h}}^{(5)} = \mathbf{C} \hat{\mathbf{h}}^{(4)}
   1567 
   1568     where
   1569 
   1570     .. math::
   1571         \mathbf{C} = \mathbf{R^{(s)}} \left( \mathbf{R^{(s)}} + \mathbf{\Sigma}^{(4)} \right)^{-1}.
   1572 
   1573     The estimation error variances are the digonal coefficients of
   1574 
   1575     .. math::
   1576         \mathbf{\Sigma}^{(5)} = \mathbf{R^{(s)}} - \mathbf{C}\mathbf{R^{(s)}}
   1577 
   1578     The smoothed channel estimate :math:`\hat{\mathbf{h}}^{(5)}` and corresponding
   1579     error variances :math:`\text{diag}\left( \mathbf{\Sigma}^{(5)} \right)` are
   1580     returned for every resource element :math:`(m,n)`.
   1581 
   1582     **Remark:** No scaling is performed after the last interpolation or smoothing
   1583     step.
   1584 
   1585     **Remark:** All passes assume that the estimation error covariance matrix
   1586     (:math:`\mathbf{\Sigma}`, :math:`\tilde{\mathbf{\Sigma}}^{(2)}`, or :math:`\tilde{\mathbf{\Sigma}}^{(4)}`) is diagonal, which
   1587     may not be accurate. When this assumption does not hold, this interpolator is only
   1588     an approximation of LMMSE interpolation.
   1589 
   1590     **Remark:** The order in which frequency interpolation, temporal
   1591     interpolation, and, optionally, spatial smoothing are applied, is controlled using the
   1592     ``order`` parameter.
   1593 
   1594     Note
   1595     ----
   1596     This layer does not support graph mode with XLA.
   1597 
   1598     Parameters
   1599     ----------
   1600     pilot_pattern : PilotPattern
   1601         An instance of :class:`~sionna.ofdm.PilotPattern`
   1602 
   1603     cov_mat_time : [num_ofdm_symbols, num_ofdm_symbols], tf.complex
   1604         Time covariance matrix of the channel
   1605 
   1606     cov_mat_freq : [fft_size, fft_size], tf.complex
   1607         Frequency covariance matrix of the channel
   1608 
   1609     cov_time_space : [num_rx_ant, num_rx_ant], tf.complex
   1610         Spatial covariance matrix of the channel.
   1611         Defaults to `None`.
   1612         Only required if spatial smoothing is requested (see ``order``).
   1613 
   1614     order : str
   1615         Order in which to perform interpolation and optional smoothing.
   1616         For example, ``"t-f-s"`` means that interpolation across the OFDM symbols
   1617         is performed first (``"t"``: time), followed by interpolation across the
   1618         sub-carriers (``"f"``: frequency), and finally smoothing across the
   1619         receive antennas (``"s"``: space).
   1620         Similarly, ``"f-t"`` means interpolation across the sub-carriers followed
   1621         by interpolation across the OFDM symbols and no spatial smoothing.
   1622         The spatial covariance matrix (``cov_time_space``) is only required when
   1623         spatial smoothing is requested.
   1624         Time and frequency interpolation are not optional to ensure that a channel
   1625         estimate is computed for all resource elements.
   1626 
   1627     Input
   1628     -----
   1629     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
   1630         Channel estimates for the pilot-carrying resource elements
   1631 
   1632     err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
   1633         Channel estimation error variances for the pilot-carrying resource elements
   1634 
   1635     Output
   1636     ------
   1637     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
   1638         Channel estimates accross the entire resource grid for all
   1639         transmitters and streams
   1640 
   1641     err_var : Same shape as ``h_hat``, tf.float
   1642         Channel estimation error variances accross the entire resource grid
   1643         for all transmitters and streams
   1644     """
   1645 
   1646     def __init__(self, pilot_pattern, cov_mat_time, cov_mat_freq,
   1647                     cov_mat_space=None, order='t-f'):
   1648 
   1649         # Check the specified order
   1650         order = order.split('-')
   1651         assert 2 <= len(order) <= 3, "Invalid order for interpolation."
   1652         spatial_smoothing = False
   1653         freq_smoothing = False
   1654         time_smoothing = False
   1655         for o in order:
   1656             assert o in ('s', 'f', 't'), f"Uknown dimension {o}"
   1657             if o == 's':
   1658                 assert not spatial_smoothing,\
   1659                     "Spatial smoothing can be specified at most once"
   1660                 spatial_smoothing = True
   1661             elif o == 't':
   1662                 assert not time_smoothing,\
   1663                     "Temporal interpolation can be specified once only"
   1664                 time_smoothing = True
   1665             elif o == 'f':
   1666                 assert not freq_smoothing,\
   1667                     "Frequency interpolation can be specified once only"
   1668                 freq_smoothing = True
   1669         if spatial_smoothing:
   1670             assert cov_mat_space is not None,\
   1671                 "A spatial covariance matrix is required for spatial smoothing"
   1672         assert freq_smoothing, "Frequency interpolation is required"
   1673         assert time_smoothing, "Time interpolation is required"
   1674 
   1675         self._order = order
   1676         self._num_ofdm_symbols = pilot_pattern.num_ofdm_symbols
   1677         self._num_effective_subcarriers =pilot_pattern.num_effective_subcarriers
   1678 
   1679         # Build pilot masks for every stream
   1680         pilot_mask = self._build_pilot_mask(pilot_pattern)
   1681 
   1682         # Build indices for mapping channel estimates and
   1683         # error variances that are given as input to a
   1684         # resource grid
   1685         num_pilots = pilot_pattern.pilots.shape[2]
   1686         inputs_to_rg_indices = self._build_inputs2rg_indices(pilot_mask,
   1687                                                              num_pilots)
   1688         self._inputs_to_rg_indices = tf.cast(inputs_to_rg_indices, tf.int32)
   1689 
   1690         # 1D interpolator according to requested order
   1691         # Interpolation is always performed along the inner dimension.
   1692         interpolators = []
   1693         # Masks for masking error variances that were not updated
   1694         err_var_masks = []
   1695         for i, o in enumerate(order):
   1696             # Is it the last one?
   1697             last_step = i == len(order)-1
   1698             # Frequency
   1699             if o == "f":
   1700                 interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_freq,
   1701                                                         last_step=last_step)
   1702                 pilot_mask = self._update_pilot_mask_interp(pilot_mask)
   1703                 err_var_mask = tf.cast(pilot_mask == 1,
   1704                                         cov_mat_freq.dtype.real_dtype)
   1705             # Time
   1706             elif o == 't':
   1707                 pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2])
   1708                 interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_time,
   1709                                                         last_step=last_step)
   1710                 pilot_mask = self._update_pilot_mask_interp(pilot_mask)
   1711                 pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2])
   1712                 err_var_mask = tf.cast(pilot_mask == 1,
   1713                                             cov_mat_freq.dtype.real_dtype)
   1714             # Space
   1715             else:
   1716                 interpolator = SpatialChannelFilter(cov_mat_space,
   1717                                                     last_step=last_step)
   1718                 err_var_mask = tf.cast(pilot_mask == 1,
   1719                                             cov_mat_freq.dtype.real_dtype)
   1720             interpolators.append(interpolator)
   1721             err_var_masks.append(err_var_mask)
   1722         self._interpolators = interpolators
   1723         self._err_var_masks = err_var_masks
   1724 
   1725     def _build_pilot_mask(self, pilot_pattern):
   1726         """
   1727         Build for every transmitter and stream a pilot mask indicating
   1728         which REs are allocated to pilots, data, or not used.
   1729         # 0 -> Data
   1730         # 1 -> Pilot
   1731         # 2 -> Not used
   1732         """
   1733 
   1734         mask = pilot_pattern.mask
   1735         pilots = pilot_pattern.pilots
   1736         num_tx = mask.shape[0]
   1737         num_streams_per_tx = mask.shape[1]
   1738         num_ofdm_symbols = mask.shape[2]
   1739         num_effective_subcarriers = mask.shape[3]
   1740 
   1741         pilot_mask = np.zeros([num_tx, num_streams_per_tx, num_ofdm_symbols,
   1742                                 num_effective_subcarriers], int)
   1743         for tx,st in itertools.product( range(num_tx),
   1744                                         range(num_streams_per_tx)):
   1745             pil_index = 0
   1746             for sb,sc in itertools.product( range(num_ofdm_symbols),
   1747                                             range(num_effective_subcarriers)):
   1748                 if mask[tx,st,sb,sc] == 1:
   1749                     if np.abs(pilots[tx,st,pil_index]) > 0.0:
   1750                         pilot_mask[tx,st,sb,sc] = 1
   1751                     else:
   1752                         pilot_mask[tx,st,sb,sc] = 2
   1753                     pil_index += 1
   1754 
   1755         return pilot_mask
   1756 
   1757     def _build_inputs2rg_indices(self, pilot_mask, num_pilots):
   1758         """
   1759         Builds indices for mapping channel estimates and
   1760         error variances that are given as input to a
   1761         resource grid
   1762         """
   1763 
   1764         num_tx = pilot_mask.shape[0]
   1765         num_streams_per_tx = pilot_mask.shape[1]
   1766         num_ofdm_symbols = pilot_mask.shape[2]
   1767         num_effective_subcarriers = pilot_mask.shape[3]
   1768 
   1769         inputs_to_rg_indices = np.zeros([num_tx, num_streams_per_tx,
   1770                                          num_pilots, 4], int)
   1771         for tx,st in itertools.product( range(num_tx),
   1772                                         range(num_streams_per_tx)):
   1773             pil_index = 0 # Pilot index for this stream
   1774             for sb,sc in itertools.product( range(num_ofdm_symbols),
   1775                                             range(num_effective_subcarriers)):
   1776                 if pilot_mask[tx,st,sb,sc] == 0:
   1777                     continue
   1778                 if pilot_mask[tx,st,sb,sc] == 1:
   1779                     inputs_to_rg_indices[tx, st, pil_index] = [tx, st, sb, sc]
   1780                 pil_index += 1
   1781 
   1782         return inputs_to_rg_indices
   1783 
   1784     def _update_pilot_mask_interp(self, pilot_mask):
   1785         """
   1786         Update the pilot mask to label the resource elements for which the
   1787         channel was interpolated.
   1788         """
   1789 
   1790         interpolated = np.any(pilot_mask == 1, axis=-1, keepdims=True)
   1791         pilot_mask = np.where(interpolated, 1, pilot_mask)
   1792 
   1793         return pilot_mask
   1794 
   1795     def __call__(self, h_hat, err_var):
   1796 
   1797         # h_hat : [batch_size, num_rx, num_rx_ant, num_tx,
   1798         #          num_streams_per_tx, num_pilots]
   1799         # err_var : [batch_size, num_rx, num_rx_ant, num_tx,
   1800         #          num_streams_per_tx, num_pilots]
   1801 
   1802         batch_size = tf.shape(h_hat)[0]
   1803         num_rx = tf.shape(h_hat)[1]
   1804         num_rx_ant = tf.shape(h_hat)[2]
   1805         num_tx = tf.shape(h_hat)[3]
   1806         num_tx_stream = tf.shape(h_hat)[4]
   1807         num_ofdm_symbols = self._num_ofdm_symbols
   1808         num_effective_subcarriers = self._num_effective_subcarriers
   1809 
   1810         # For some estimator, err_var might not have the same shape
   1811         # as h_hat
   1812         err_var = tf.broadcast_to(err_var, tf.shape(h_hat))
   1813 
   1814         # Mapping the channel estimates and error variances to a resource grid
   1815         # all : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
   1816         #           num_ofdm_symbols, num_effective_subcarriers]
   1817         h_hat = tf.transpose(h_hat, [3, 4, 5, 0, 1, 2])
   1818         err_var = tf.transpose(err_var, [3, 4, 5, 0, 1, 2])
   1819         h_hat = tf.scatter_nd(self._inputs_to_rg_indices, h_hat,
   1820                                             [num_tx, num_tx_stream,
   1821                                              num_ofdm_symbols,
   1822                                              num_effective_subcarriers,
   1823                                              batch_size, num_rx, num_rx_ant])
   1824         err_var = tf.scatter_nd(self._inputs_to_rg_indices, err_var,
   1825                                             [num_tx, num_tx_stream,
   1826                                              num_ofdm_symbols,
   1827                                              num_effective_subcarriers,
   1828                                              batch_size, num_rx, num_rx_ant])
   1829         h_hat = tf.transpose(h_hat, [4, 5, 6, 0, 1, 2, 3])
   1830         err_var = tf.transpose(err_var, [4, 5, 6, 0, 1, 2, 3])
   1831 
   1832         # Interpolation
   1833         # Performed according to the requested order. Transpose are used as
   1834         # 1D interpolation is performed along the inner axis.
   1835         items = zip(self._order, self._interpolators, self._err_var_masks)
   1836         for o,interp,err_var_mask in items:
   1837             # Frequency
   1838             if o == 'f':
   1839                 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
   1840                 #           num_ofdm_symbols, num_effective_subcarriers]
   1841                 h_hat, err_var = interp(h_hat, err_var)
   1842                 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0)
   1843                 err_var = err_var*err_var_mask
   1844             # Time
   1845             elif o == 't':
   1846                 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
   1847                 #           num_effective_subcarriers, num_ofdm_symbols]
   1848                 h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5])
   1849                 err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5])
   1850                 h_hat, err_var = interp(h_hat, err_var)
   1851                 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
   1852                 #           num_ofdm_symbols, num_effective_subcarriers]
   1853                 h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5])
   1854                 err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5])
   1855                 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0)
   1856                 err_var = err_var*err_var_mask
   1857             # Space
   1858             elif o == 's':
   1859                 # [batch_size, num_rx, num_tx, num_streams_per_tx,
   1860                 #      num_ofdm_symbols, num_effective_subcarriers, num_rx_ant]
   1861                 h_hat = tf.transpose(h_hat, [0, 1, 3, 4, 5, 6, 2])
   1862                 err_var = tf.transpose(err_var, [0, 1, 3, 4, 5, 6, 2])
   1863                 h_hat, err_var = interp(h_hat, err_var)
   1864                 # [batch_size, num_rx, num_tx, num_streams_per_tx,
   1865                 #      num_ofdm_symbols, num_effective_subcarriers, num_rx_ant]
   1866                 h_hat = tf.transpose(h_hat, [0, 1, 6, 2, 3, 4, 5])
   1867                 err_var = tf.transpose(err_var, [0, 1, 6, 2, 3, 4, 5])
   1868                 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0)
   1869                 err_var = err_var*err_var_mask
   1870 
   1871         return h_hat, err_var
   1872 
   1873 #######################################################
   1874 # Utilities
   1875 #######################################################
   1876 
   1877 def tdl_freq_cov_mat(model, subcarrier_spacing, fft_size, delay_spread,
   1878                         dtype=tf.complex64):
   1879     # pylint: disable=line-too-long
   1880     r"""
   1881     Computes the frequency covariance matrix of a
   1882     :class:`~sionna.channel.tr38901.TDL` channel model.
   1883 
   1884     The channel frequency covariance matrix :math:`\mathbf{R}^{(f)}` of a TDL channel model is
   1885 
   1886     .. math::
   1887         \mathbf{R}^{(f)}_{u,v} = \sum_{\ell=1}^L P_\ell e^{-j 2 \pi \tau_\ell \Delta_f (u-v)}, 1 \leq u,v \leq M
   1888 
   1889     where :math:`M` is the FFT size, :math:`L` is the number of paths for the selected TDL model,
   1890     :math:`P_\ell` and :math:`\tau_\ell` are the average power and delay for the
   1891     :math:`\ell^{\text{th}}` path, respectively, and :math:`\Delta_f` is the sub-carrier spacing.
   1892 
   1893     Input
   1894     ------
   1895     model : str
   1896         TDL model for which to return the covariance matrix.
   1897         Should be one of "A", "B", "C", "D", or "E".
   1898 
   1899     subcarrier_spacing : float
   1900         Sub-carrier spacing [Hz]
   1901 
   1902     fft_size : float
   1903         FFT size
   1904 
   1905     delay_spread : float
   1906         Delay spread [s]
   1907 
   1908     dtype : tf.DType
   1909         Datatype to use for the output.
   1910         Should be one of `tf.complex64` or `tf.complex128`.
   1911         Defaults to `tf.complex64`.
   1912 
   1913     Output
   1914     ------
   1915         cov_mat : [fft_size, fft_size], tf.complex
   1916             Channel frequency covariance matrix
   1917     """
   1918 
   1919     assert dtype in (tf.complex64, tf.complex128),\
   1920         "The `dtype` should be a complex datatype"
   1921 
   1922     #
   1923     # Load the power delay profile
   1924     #
   1925 
   1926     # Set the file from which to load the model
   1927     assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model"
   1928     if model == 'A':
   1929         parameters_fname = "TDL-A.json"
   1930     elif model == 'B':
   1931         parameters_fname = "TDL-B.json"
   1932     elif model == 'C':
   1933         parameters_fname = "TDL-C.json"
   1934     elif model == 'D':
   1935         parameters_fname = "TDL-D.json"
   1936     else: # 'E'
   1937         parameters_fname = "TDL-E.json"
   1938     source = files(models).joinpath(parameters_fname)
   1939     # pylint: disable=unspecified-encoding
   1940     with open(source) as parameter_file:
   1941         params = json.load(parameter_file)
   1942     # LoS scenario ?
   1943     los = bool(params['los'])
   1944     # Retrieve power and delays
   1945     delays = np.array(params['delays'])*delay_spread
   1946     mean_powers = np.power(10.0, np.array(params['powers'])/10.0)
   1947 
   1948     if los:
   1949         # Add the power of the specular and non-specular component of
   1950         # the first path
   1951         mean_powers[0] = mean_powers[0] + mean_powers[1]
   1952         mean_powers = np.concatenate([mean_powers[:1], mean_powers[2:]], axis=0)
   1953         # The first two paths have 0 delays as they correspond to the
   1954         # specular and reflected components of the first path.
   1955         delays = delays[1:]
   1956 
   1957     # Normalize the PDP
   1958     norm_factor = np.sum(mean_powers)
   1959     mean_powers = mean_powers / norm_factor
   1960 
   1961     #
   1962     # Build frequency covariance matrix
   1963     #
   1964 
   1965     n = np.arange(fft_size)
   1966     p = -2.*np.pi*subcarrier_spacing*n
   1967     p = np.expand_dims(p, axis=0)
   1968     delays = np.expand_dims(delays, axis=1)
   1969     p = p*delays
   1970     p = np.exp(1j*p)
   1971     p = np.expand_dims(p, axis=-1)
   1972     cov_mat = np.matmul(p, np.transpose(np.conj(p), [0, 2, 1]))
   1973     mean_powers = np.expand_dims(mean_powers, axis=(1,2))
   1974     cov_mat = np.sum(mean_powers*cov_mat, axis=0)
   1975 
   1976     return tf.cast(cov_mat, dtype)
   1977 
   1978 def tdl_time_cov_mat(model, speed, carrier_frequency, ofdm_symbol_duration,
   1979         num_ofdm_symbols, los_angle_of_arrival=PI/4., dtype=tf.complex64):
   1980     # pylint: disable=line-too-long
   1981     r"""
   1982     Computes the time covariance matrix of a
   1983     :class:`~sionna.channel.tr38901.TDL` channel model.
   1984 
   1985     For non-line-of-sight (NLoS) model, the channel time covariance matrix
   1986     :math:`\mathbf{R^{(t)}}` of a TDL channel model is
   1987 
   1988     .. math::
   1989         \mathbf{R^{(t)}}_{u,v} = J_0 \left( \nu \Delta_t \left( u-v \right) \right)
   1990 
   1991     where :math:`J_0` is the zero-order Bessel function of the first kind,
   1992     :math:`\Delta_t` the duration of an OFDM symbol, and :math:`\nu` the Doppler
   1993     spread defined by
   1994 
   1995     .. math::
   1996         \nu = 2 \pi \frac{v}{c} f_c
   1997 
   1998     where :math:`v` is the movement speed, :math:`c` the speed of light, and
   1999     :math:`f_c` the carrier frequency.
   2000 
   2001     For line-of-sight (LoS) channel models, the channel time covariance matrix
   2002     is
   2003 
   2004     .. math::
   2005         \mathbf{R^{(t)}}_{u,v} = P_{\text{NLoS}} J_0 \left( \nu \Delta_t \left( u-v \right) \right) + P_{\text{LoS}}e^{j \nu \Delta_t \left( u-v \right) \cos{\alpha_{\text{LoS}}}}
   2006 
   2007     where :math:`\alpha_{\text{LoS}}` is the angle-of-arrival for the LoS path,
   2008     :math:`P_{\text{NLoS}}` the total power of NLoS paths, and
   2009     :math:`P_{\text{LoS}}` the power of the LoS path. The power delay profile
   2010     is assumed to have unit power, i.e., :math:`P_{\text{NLoS}} + P_{\text{LoS}} = 1`.
   2011 
   2012     Input
   2013     ------
   2014     model : str
   2015         TDL model for which to return the covariance matrix.
   2016         Should be one of "A", "B", "C", "D", or "E".
   2017 
   2018     speed : float
   2019         Speed [m/s]
   2020 
   2021     carrier_frequency : float
   2022         Carrier frequency [Hz]
   2023 
   2024     ofdm_symbol_duration : float
   2025         Duration of an OFDM symbol [s]
   2026 
   2027     num_ofdm_symbols : int
   2028         Number of OFDM symbols
   2029 
   2030     los_angle_of_arrival : float
   2031         Angle-of-arrival for LoS path [radian]. Only used with LoS models.
   2032         Defaults to :math:`\pi/4`.
   2033 
   2034     dtype : tf.DType
   2035         Datatype to use for the output.
   2036         Should be one of `tf.complex64` or `tf.complex128`.
   2037         Defaults to `tf.complex64`.
   2038 
   2039     Output
   2040     ------
   2041         cov_mat : [num_ofdm_symbols, num_ofdm_symbols], tf.complex
   2042             Channel time covariance matrix
   2043     """
   2044 
   2045     # Doppler spread
   2046     doppler_spread = 2.*PI*speed/SPEED_OF_LIGHT*carrier_frequency
   2047 
   2048     #
   2049     # Load the power delay profile
   2050     #
   2051 
   2052     # Set the file from which to load the model
   2053     assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model"
   2054     if model == 'A':
   2055         parameters_fname = "TDL-A.json"
   2056     elif model == 'B':
   2057         parameters_fname = "TDL-B.json"
   2058     elif model == 'C':
   2059         parameters_fname = "TDL-C.json"
   2060     elif model == 'D':
   2061         parameters_fname = "TDL-D.json"
   2062     else: # 'E'
   2063         parameters_fname = "TDL-E.json"
   2064     source = files(models).joinpath(parameters_fname)
   2065     # pylint: disable=unspecified-encoding
   2066     with open(source) as parameter_file:
   2067         params = json.load(parameter_file)
   2068     # LoS scenario ?
   2069     los = bool(params['los'])
   2070     # Retrieve power and delays
   2071     mean_powers = np.power(10.0, np.array(params['powers'])/10.0)
   2072 
   2073     # Normalize the PDP
   2074     norm_factor = np.sum(mean_powers)
   2075     mean_powers = mean_powers / norm_factor
   2076 
   2077     if los:
   2078         los_power = mean_powers[0]
   2079         nlos_power = np.sum(mean_powers[1:])
   2080     else:
   2081         nlos_power = np.sum(mean_powers)
   2082 
   2083     #
   2084     # Build time covariance matrix
   2085     #
   2086 
   2087     indices = np.arange(num_ofdm_symbols)
   2088     s1 = np.expand_dims(indices, axis=1)
   2089     s2 = np.expand_dims(indices, axis=0)
   2090     exp = doppler_spread*ofdm_symbol_duration*(s1-s2)
   2091     cov_mat_nlos = jv(0.0, exp)*nlos_power
   2092     if los:
   2093         cov_mat_los = np.exp(1j*exp*np.cos(los_angle_of_arrival))*los_power
   2094         cov_mat = cov_mat_nlos+cov_mat_los
   2095     else:
   2096         cov_mat = cov_mat_nlos
   2097 
   2098     return tf.cast(cov_mat, dtype)