channel_estimation.py (89915B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions related to OFDM channel estimation"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np
from sionna.channel.tr38901 import models
from sionna.utils import flatten_last_dims, expand_to_rank, matrix_inv
from sionna.ofdm import ResourceGrid, RemoveNulledSubcarriers
from sionna import PI, SPEED_OF_LIGHT
from scipy.special import jv
import itertools
from abc import ABC, abstractmethod
import json
from importlib_resources import files

class BaseChannelEstimator(ABC, Layer):
    # pylint: disable=line-too-long
    r"""BaseChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)

    Abstract layer for implementing an OFDM channel estimator.

    Any layer that implements an OFDM channel estimator must implement this
    class and its
    :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
    abstract method.

    This class extracts the pilots from the received resource grid ``y``, calls
    the :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
    method to estimate the channel for the pilot-carrying resource elements,
    and then interpolates the channel to compute channel estimates for the
    data-carrying resource elements using the interpolation method specified by
    ``interpolation_type`` or the ``interpolator`` object.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    interpolation_type : One of ["nn", "lin", "lin_time_avg"], string, or `None`
        The interpolation method to be used.
        It is ignored if ``interpolator`` is not `None`.
        Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn"`)
        or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
        averaging across OFDM symbols (`"lin_time_avg"`).
        If `None` and no ``interpolator`` is given, no interpolation is
        performed and the estimates for the pilot-carrying resource elements
        are returned as computed by
        :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`.
        Defaults to "nn".

    interpolator : BaseChannelInterpolator
        An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
        such as :class:`~sionna.ofdm.LMMSEInterpolator`,
        or `None`. In the latter case, the interpolator specified
        by ``interpolation_type`` is used.
        Otherwise, the ``interpolator`` is used and ``interpolation_type``
        is ignored.
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
        Observed resource grid

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self, resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(resource_grid, ResourceGrid),\
            "You must provide a valid instance of ResourceGrid."
        self._pilot_pattern = resource_grid.pilot_pattern
        self._removed_nulled_scs = RemoveNulledSubcarriers(resource_grid)

        assert interpolation_type in ["nn","lin","lin_time_avg",None], \
            "Unsupported `interpolation_type`"
        self._interpolation_type = interpolation_type

        # An explicitly provided interpolator always takes precedence over
        # `interpolation_type`, as documented.
        if interpolator is not None:
            assert isinstance(interpolator, BaseChannelInterpolator), \
                "`interpolator` must implement the BaseChannelInterpolator interface"
            self._interpol = interpolator
        elif self._interpolation_type == "nn":
            self._interpol = NearestNeighborInterpolator(self._pilot_pattern)
        elif self._interpolation_type == "lin":
            self._interpol = LinearInterpolator(self._pilot_pattern)
        elif self._interpolation_type == "lin_time_avg":
            self._interpol = LinearInterpolator(self._pilot_pattern,
                                                time_avg=True)
        else:
            # Neither an interpolator nor an interpolation type was given:
            # interpolation is skipped in `call`.
            self._interpol = None

        # Precompute indices to gather received pilot signals.
        # The sort must be stable so that the pilot-carrying resource elements
        # keep their order within the flattened grid, i.e., the k-th gathered
        # element is matched with the k-th pilot of the sequence.
        num_pilot_symbols = self._pilot_pattern.num_pilot_symbols
        mask = flatten_last_dims(self._pilot_pattern.mask)
        pilot_ind = tf.argsort(mask, axis=-1, direction="DESCENDING",
                               stable=True)
        self._pilot_ind = pilot_ind[...,:num_pilot_symbols]

    @abstractmethod
    def estimate_at_pilot_locations(self, y_pilots, no):
        """
        Estimates the channel for the pilot-carrying resource elements.

        This is an abstract method that must be implemented by a concrete
        OFDM channel estimator that implements this class.

        Input
        -----
        y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Observed signals for the pilot-carrying resource elements

        no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
            Variance of the AWGN

        Output
        ------
        h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Channel estimates for the pilot-carrying resource elements

        err_var : Same shape as ``h_hat``, tf.float
            Channel estimation error variance for the pilot-carrying
            resource elements
        """
        pass

    def call(self, inputs):

        y, no = inputs

        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,..
        # ... fft_size]
        #
        # no can have shapes [], [batch_size], [batch_size, num_rx]
        # or [batch_size, num_rx, num_rx_ant]

        # Removed nulled subcarriers (guards, dc)
        y_eff = self._removed_nulled_scs(y)

        # Flatten the resource grid for pilot extraction
        # New shape: [...,num_ofdm_symbols*num_effective_subcarriers]
        y_eff_flat = flatten_last_dims(y_eff)

        # Gather pilots along the last dimensions
        # Resulting shape: y_eff_flat.shape[:-1] + pilot_ind.shape, i.e.:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_pilot_symbols]
        y_pilots = tf.gather(y_eff_flat, self._pilot_ind, axis=-1)

        # Estimate the channel at the pilot positions. The estimation method
        # itself is provided by the concrete subclass.
        h_hat, err_var = self.estimate_at_pilot_locations(y_pilots, no)

        # Interpolate channel estimates over the resource grid.
        # Skipped only if no interpolator was configured.
        if self._interpol is not None:
            h_hat, err_var = self._interpol(h_hat, err_var)
            # Interpolation may produce slightly negative values;
            # a variance is clipped to be non-negative.
            err_var = tf.maximum(err_var, tf.cast(0, err_var.dtype))

        return h_hat, err_var


class LSChannelEstimator(BaseChannelEstimator, Layer):
    # pylint: disable=line-too-long
    r"""LSChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)

    Layer implementing least-squares (LS) channel estimation for OFDM MIMO systems.

    After LS channel estimation at the pilot positions, the channel estimates
    and error variances are interpolated across the entire resource grid using
    a specified interpolation function.

    For simplicity, the underlying algorithm is described for a vectorized observation,
    where we have a nonzero pilot for all elements to be estimated.
    The actual implementation works on a full OFDM resource grid with sparse
    pilot patterns. The following model is assumed:

    .. math::

        \mathbf{y} = \mathbf{h}\odot\mathbf{p} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the received signal vector,
    :math:`\mathbf{p}\in\mathbb{C}^M` is the vector of pilot symbols,
    :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean noise vector whose
    elements have variance :math:`N_0`. The operator :math:`\odot` denotes
    element-wise multiplication.

    The channel estimate :math:`\hat{\mathbf{h}}` and error variances
    :math:`\sigma^2_i`, :math:`i=0,\dots,M-1`, are computed as

    .. math::

        \hat{\mathbf{h}} &= \mathbf{y} \odot
                           \frac{\mathbf{p}^\star}{\left|\mathbf{p}\right|^2}
                         = \mathbf{h} + \tilde{\mathbf{h}}\\
             \sigma^2_i &= \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right]
                         = \frac{N_0}{\left|p_i\right|^2}.

    The channel estimates and error variances are then interpolated across
    the entire resource grid.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
        The interpolation method to be used.
        It is ignored if ``interpolator`` is not `None`.
        Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn"`)
        or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
        averaging across OFDM symbols (`"lin_time_avg"`).
        Defaults to "nn".

    interpolator : BaseChannelInterpolator
        An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
        such as :class:`~sionna.ofdm.LMMSEInterpolator`,
        or `None`. In the latter case, the interpolator specified
        by ``interpolation_type`` is used.
        Otherwise, the ``interpolator`` is used and ``interpolation_type``
        is ignored.
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
        Observed resource grid

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    h_ls : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_ls``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """

    def estimate_at_pilot_locations(self, y_pilots, no):

        # y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams,
        #               num_pilot_symbols], tf.complex
        #     The observed signals for the pilot-carrying resource elements.

        # no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims,
        #   tf.float
        #     The variance of the AWGN.

        # Compute LS channel estimates
        # Note: Some might be Inf because pilots=0, but we do not care
        # as only the valid estimates will be considered during interpolation.
        # We do a safe division to replace Inf by 0.
        # Broadcasting from pilots here is automatic since pilots have shape
        # [num_tx, num_streams, num_pilot_symbols]
        h_ls = tf.math.divide_no_nan(y_pilots, self._pilot_pattern.pilots)

        # Compute error variance and broadcast to the same shape as h_ls.
        # `no` is first brought to the real dtype matching the estimates so
        # that the division below does not fail on mixed precisions
        # (e.g., a float32 or Python-float `no` with complex128 estimates).
        no = tf.cast(no, h_ls.dtype.real_dtype)

        # Expand rank of no for broadcasting
        no = expand_to_rank(no, tf.rank(h_ls), -1)

        # Expand rank of pilots for broadcasting
        pilots = expand_to_rank(self._pilot_pattern.pilots, tf.rank(h_ls), 0)

        # Compute error variance, broadcastable to the shape of h_ls.
        # Zero-energy pilots yield an error variance of 0 (safe division).
        err_var = tf.math.divide_no_nan(no, tf.abs(pilots)**2)

        return h_ls, err_var


class BaseChannelInterpolator(ABC):
    # pylint: disable=line-too-long
    r"""BaseChannelInterpolator()

    Abstract layer for implementing an OFDM channel interpolator.

    Any layer that implements an OFDM channel interpolator must implement this
    callable class.

    A channel interpolator is used by an OFDM channel estimator
    (:class:`~sionna.ofdm.BaseChannelEstimator`) to compute channel estimates
    for the data-carrying resource elements from the channel estimates for the
    pilot-carrying resource elements.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """

    @abstractmethod
    def __call__(self, h_hat, err_var):
        pass


class NearestNeighborInterpolator(BaseChannelInterpolator):
    # pylint: disable=line-too-long
    r"""NearestNeighborInterpolator(pilot_pattern)

    Nearest-neighbor channel estimate interpolation on a resource grid.

    This class assigns to each element of an OFDM resource grid one of
    ``num_pilots`` provided channel estimates and error
    variances according to the nearest neighbor method. It is assumed
    that the measurements were taken at the nonzero positions of a
    :class:`~sionna.ofdm.PilotPattern`.

    The figure below shows how four channel estimates are interpolated
    across a resource grid. Grey fields indicate measurement positions
    while the colored regions show which resource elements are assigned
    to the same measurement value.

    .. image:: ../figures/nearest_neighbor_interpolation.png

    Parameters
    ----------
    pilot_pattern : PilotPattern
        An instance of :class:`~sionna.ofdm.PilotPattern`

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variances across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self, pilot_pattern):
        super().__init__()

        assert(pilot_pattern.num_pilot_symbols>0),\
            """The pilot pattern cannot be empty"""

        # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
        mask = np.array(pilot_pattern.mask)
        mask_shape = mask.shape # Store to reconstruct the original shape
        mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))

        # Reshape the pilots to shape [-1, num_pilot_symbols]
        pilots = pilot_pattern.pilots
        pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])

        max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
        assert max_num_zero_pilots<pilots.shape[-1],\
            """Each pilot sequence must have at least one nonzero entry"""

        # Compute gather indices for nearest neighbor interpolation
        gather_ind = np.zeros_like(mask, dtype=np.int32)

        # Value larger than any Manhattan distance within the grid. It is
        # assigned to zero-energy pilots so that they are never selected.
        max_dist = np.sum(mask_shape[-2:])

        # Sub-carrier index of every resource element of an OFDM symbol,
        # shaped [num_effective_subcarriers, 1] for broadcasting against
        # the pilot positions.
        j_all = np.arange(mask_shape[-1])[:, np.newaxis]

        for a in range(gather_ind.shape[0]): # For each pilot pattern...
            i_p, j_p = np.where(mask[a]) # ...determine the pilot indices

            # Pilots with zero energy carry no measurement
            zero_energy = np.abs(pilots[a])==0

            # Distance along the sub-carrier axis between every resource
            # element of an OFDM symbol and every pilot.
            # Shape: [num_effective_subcarriers, num_pilot_symbols]
            d_freq = np.abs(j_all - j_p[np.newaxis, :])

            for i in range(mask_shape[-2]): # Iterate over all OFDM symbols

                # Manhattan distance of all resource elements of this
                # OFDM symbol to all pilot positions
                d = d_freq + np.abs(i - i_p)[np.newaxis, :]

                # Set the distance at all pilot positions with zero energy
                # equal to the maximum possible distance
                d[:, zero_energy] = max_dist

                # Find for each resource element the pilot index with the
                # shortest distance (first one in case of ties) and store
                # it in the index tensor
                gather_ind[a, i] = np.argmin(d, axis=-1)

        # Reshape to the original shape of the mask, i.e.:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers]
        self._gather_ind = tf.reshape(gather_ind, mask_shape)

    def _interpolate(self, inputs):
        # inputs has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_pilots]

        # Transpose inputs to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, num_pilots, k, l, m]
        perm = tf.roll(tf.range(tf.rank(inputs)), -3, 0)
        inputs = tf.transpose(inputs, perm)

        # Interpolate through gather. Shape:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  ..., num_effective_subcarriers, k, l, m]
        outputs = tf.gather(inputs, self._gather_ind, 2, batch_dims=2)

        # Transpose outputs to bring batch_dims first again. New shape:
        # [k, l, m, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = tf.roll(tf.range(tf.rank(outputs)), 3, 0)
        outputs = tf.transpose(outputs, perm)

        return outputs

    def __call__(self, h_hat, err_var):

        h_hat = self._interpolate(h_hat)
        err_var = self._interpolate(err_var)
        return h_hat, err_var


class LinearInterpolator(BaseChannelInterpolator):
    # pylint: disable=line-too-long
    r"""LinearInterpolator(pilot_pattern, time_avg=False)

    Linear channel estimate interpolation on a resource grid.

    This class computes for each element of an OFDM resource grid
    a channel estimate based on ``num_pilots`` provided channel estimates and
    error variances through linear interpolation.
    It is assumed that the measurements were taken at the nonzero positions
    of a :class:`~sionna.ofdm.PilotPattern`.

    The interpolation is done first across sub-carriers and then
    across OFDM symbols.

    Parameters
    ----------
    pilot_pattern : PilotPattern
        An instance of :class:`~sionna.ofdm.PilotPattern`

    time_avg : bool
        If enabled, measurements will be averaged across OFDM symbols
        (i.e., time). This is useful for channels that do not vary
        substantially over the duration of an OFDM frame. Defaults to `False`.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variances across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self, pilot_pattern, time_avg=False):
        super().__init__()

        assert(pilot_pattern.num_pilot_symbols>0),\
            """The pilot pattern cannot be empty"""

        self._time_avg = time_avg

        # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
        mask = np.array(pilot_pattern.mask)
        mask_shape = mask.shape # Store to reconstruct the original shape
        mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))

        # Reshape the pilots to shape [-1, num_pilot_symbols]
        pilots = pilot_pattern.pilots
        pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])

        max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
        assert max_num_zero_pilots<pilots.shape[-1],\
            """Each pilot sequence must have at least one nonzero entry"""

        # Create actual pilot patterns for each stream over the resource grid
        z = np.zeros_like(mask, dtype=pilots.dtype)
        for a in range(z.shape[0]):
            z[a][np.where(mask[a])] = pilots[a]

        # Linear interpolation works as follows:
        # We compute for each resource element (RE)
        # x_0 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
        #       the first channel measurement was taken
        # x_1 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
        #       the second channel measurement was taken
        # y_0 : The first channel estimate
        # y_1 : The second channel estimate
        # x   : The x-value (i.e., sub-carrier index or OFDM symbol)
        #
        # The linearly interpolated value y is then given as:
        # y = (x-x_0) * (y_1-y_0) / (x_1-x_0) + y_0
        #
        # The following code pre-computes various quantities and indices
        # that are needed to compute x_0, x_1, y_0, y_1, x for frequency- and
        # time-domain interpolation.

        ##
        ## Frequency-domain interpolation
        ##
        self._x_freq = tf.cast(expand_to_rank(tf.range(0, mask.shape[-1]),
                                              7,
                                              axis=0),
                               pilots.dtype)

        # Permutation indices to shift batch_dims last during gather
        self._perm_fwd_freq = tf.roll(tf.range(6), -3, 0)

        x_0_freq = np.zeros_like(mask, np.int32)
        x_1_freq = np.zeros_like(mask, np.int32)

        # Set REs of OFDM symbols without any pilot equal to -1 (dummy value)
        no_pilot_symbols = np.sum(np.abs(z), axis=-1)==0
        x_0_freq[no_pilot_symbols] = -1
        x_1_freq[no_pilot_symbols] = -1

        y_0_freq_ind = np.copy(x_0_freq) # Indices used to gather estimates
        y_1_freq_ind = np.copy(x_1_freq) # Indices used to gather estimates

        # For each stream
        for a in range(z.shape[0]):

            pilot_count = 0 # Counts the number of non-zero pilots

            # Indices of non-zero pilots within the pilots vector
            pilot_ind = np.where(np.abs(pilots[a]))[0]

            # Go through all OFDM symbols
            for i in range(x_0_freq.shape[1]):

                # Indices of non-zero pilots within the OFDM symbol
                pilot_ind_ofdm = np.where(np.abs(z[a][i]))[0]

                # If OFDM symbol contains only one non-zero pilot
                if len(pilot_ind_ofdm)==1:
                    # Set the indices of the first and second pilot to the same
                    # value for all REs of the OFDM symbol
                    x_0_freq[a][i] = pilot_ind_ofdm[0]
                    x_1_freq[a][i] = pilot_ind_ofdm[0]
                    y_0_freq_ind[a,i] = pilot_ind[pilot_count]
                    y_1_freq_ind[a,i] = pilot_ind[pilot_count]

                # If OFDM symbol contains two or more pilots
                elif len(pilot_ind_ofdm)>=2:
                    x0 = 0
                    x1 = 1

                    # Go through all resource elements of this OFDM symbol
                    for j in range(x_0_freq.shape[2]):
                        x_0_freq[a,i,j] = pilot_ind_ofdm[x0]
                        x_1_freq[a,i,j] = pilot_ind_ofdm[x1]
                        y_0_freq_ind[a,i,j] = pilot_ind[pilot_count + x0]
                        y_1_freq_ind[a,i,j] = pilot_ind[pilot_count + x1]
                        # Move on to the next pilot pair once the second
                        # pilot is reached (the last pair is kept for the
                        # remaining REs)
                        if j==pilot_ind_ofdm[x1] and x1<len(pilot_ind_ofdm)-1:
                            x0 = x1
                            x1 += 1

                pilot_count += len(pilot_ind_ofdm)

        x_0_freq = np.reshape(x_0_freq, mask_shape)
        x_1_freq = np.reshape(x_1_freq, mask_shape)
        x_0_freq = expand_to_rank(x_0_freq, 7, axis=0)
        x_1_freq = expand_to_rank(x_1_freq, 7, axis=0)
        self._x_0_freq = tf.cast(x_0_freq, pilots.dtype)
        self._x_1_freq = tf.cast(x_1_freq, pilots.dtype)

        # We add +1 here to shift all indices as the input will be padded
        # at the beginning with 0, (i.e., the dummy index -1 will become 0).
        self._y_0_freq_ind = np.reshape(y_0_freq_ind, mask_shape)+1
        self._y_1_freq_ind = np.reshape(y_1_freq_ind, mask_shape)+1

        ##
        ## Time-domain interpolation
        ##
        self._x_time = tf.expand_dims(tf.range(0, mask.shape[-2]), -1)
        self._x_time = tf.cast(expand_to_rank(self._x_time, 7, axis=0),
                               dtype=pilots.dtype)

        # Indices used to gather estimates
        self._perm_fwd_time = tf.roll(tf.range(7), -3, 0)

        y_0_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices
        y_1_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices

        # For each stream
        for a in range(z.shape[0]):

            # Indices of OFDM symbols for which channel estimates were computed
            ofdm_ind = np.where(np.sum(np.abs(z[a]), axis=-1))[0]

            # Only one OFDM symbol with pilots
            if len(ofdm_ind)==1:
                y_0_time_ind[a] = ofdm_ind[0]
                y_1_time_ind[a] = ofdm_ind[0]

            # Two or more OFDM symbols with pilots
            elif len(ofdm_ind)>=2:
                x0 = 0
                x1 = 1
                for i in range(z.shape[1]):
                    y_0_time_ind[a,i] = ofdm_ind[x0]
                    y_1_time_ind[a,i] = ofdm_ind[x1]
                    if i==ofdm_ind[x1] and x1<len(ofdm_ind)-1:
                        x0 = x1
                        x1 += 1

        self._y_0_time_ind = np.reshape(y_0_time_ind, mask_shape[:-1])
        self._y_1_time_ind = np.reshape(y_1_time_ind, mask_shape[:-1])

        self._x_0_time = expand_to_rank(tf.expand_dims(self._y_0_time_ind, -1),
                                        7, axis=0)
        self._x_0_time = tf.cast(self._x_0_time, dtype=pilots.dtype)
        self._x_1_time = expand_to_rank(tf.expand_dims(self._y_1_time_ind, -1),
                                        7, axis=0)
        self._x_1_time = tf.cast(self._x_1_time, dtype=pilots.dtype)

        #
        # Other precomputed values
        #
        # Undo permutation of batch_dims for gather
        self._perm_bwd = tf.roll(tf.range(7), 3, 0)

        # Padding for the inputs
        pad = np.zeros([6, 2], np.int32)
        pad[-1, 0] = 1
        self._pad = pad

        # Number of ofdm symbols carrying at least one pilot.
        # Used for time-averaging (optional)
        n = np.sum(np.abs(np.reshape(z, mask_shape)), axis=-1, keepdims=True)
        n = np.sum(n>0, axis=-2, keepdims=True)
        self._num_pilot_ofdm_symbols = expand_to_rank(n, 7, axis=0)


    def _interpolate_1d(self, inputs, x, x0, x1, y0_ind, y1_ind):
        # Gather the right values for y0 and y1
        y0 = tf.gather(inputs, y0_ind, axis=2, batch_dims=2)
        y1 = tf.gather(inputs, y1_ind, axis=2, batch_dims=2)

        # Undo the permutation of the inputs
        y0 = tf.transpose(y0, self._perm_bwd)
        y1 = tf.transpose(y1, self._perm_bwd)

        # Compute linear interpolation.
        # The safe division yields a zero slope if x0==x1, i.e., if only a
        # single measurement is available.
        slope = tf.math.divide_no_nan(y1-y0, tf.cast(x1-x0, dtype=y0.dtype))
        return tf.cast(x-x0, dtype=y0.dtype)*slope + y0

    def _interpolate(self, inputs):
        #
        # Prepare inputs
        #
        # inputs has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_pilots]

        # Pad the inputs with a leading 0.
        # All undefined channel estimates will get this value.
        inputs = tf.pad(inputs, self._pad, constant_values=0)

        # Transpose inputs to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, 1+num_pilots, k, l, m]
        inputs = tf.transpose(inputs, self._perm_fwd_freq)

        #
        # Frequency-domain interpolation
        #
        # h_hat_freq has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers]
        h_hat_freq = self._interpolate_1d(inputs,
                                          self._x_freq,
                                          self._x_0_freq,
                                          self._x_1_freq,
                                          self._y_0_freq_ind,
                                          self._y_1_freq_ind)
        #
        # Time-domain interpolation
        #

        # Time-domain averaging (optional)
        if self._time_avg:
            num_ofdm_symbols = h_hat_freq.shape[-2]
            h_hat_freq = tf.reduce_sum(h_hat_freq, axis=-2, keepdims=True)
            h_hat_freq /= tf.cast(self._num_pilot_ofdm_symbols,h_hat_freq.dtype)
            h_hat_freq = tf.repeat(h_hat_freq, [num_ofdm_symbols], axis=-2)

        # Transpose h_hat_freq to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers, k, l, m]
        h_hat_time = tf.transpose(h_hat_freq, self._perm_fwd_time)

        # h_hat_time has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers]
        h_hat_time = self._interpolate_1d(h_hat_time,
                                          self._x_time,
                                          self._x_0_time,
                                          self._x_1_time,
                                          self._y_0_time_ind,
                                          self._y_1_time_ind)

        return h_hat_time

    def __call__(self, h_hat, err_var):

        h_hat = self._interpolate(h_hat)

        # The interpolator requires complex-valued inputs. The error variances
        # are cast to the complex dtype of the channel estimates (rather than
        # a fixed tf.complex64) so that the precision of the estimates is
        # preserved, and then mapped back to real values.
        err_var = tf.cast(err_var, h_hat.dtype)
        err_var = self._interpolate(err_var)
        err_var = tf.math.real(err_var)

        return h_hat, err_var


class LMMSEInterpolator1D:
    # pylint: disable=line-too-long
    r"""LMMSEInterpolator1D(pilot_mask, cov_mat)

    This class performs the linear interpolation across the inner dimension of the input ``h_hat``.

    The two inner dimensions of the input ``h_hat`` form a matrix :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`.
    LMMSE interpolation is performed across the inner dimension as follows:

    .. math::
        \tilde{\mathbf{h}}_n = \mathbf{A}_n \hat{\mathbf{h}}_n

    where :math:`1 \leq n \leq N` and :math:`\hat{\mathbf{h}}_n` is
    the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`.
    :math:`\mathbf{A}_n` is the :math:`M \times M` interpolation LMMSE matrix:

    .. math::
        \mathbf{A}_n = \mathbf{R} \mathbf{\Pi}_n \left( \mathbf{\Pi}_n^\intercal \mathbf{R} \mathbf{\Pi}_n + \tilde{\mathbf{\Sigma}}_n \right)^{-1} \mathbf{\Pi}_n^\intercal.

    where :math:`\mathbf{R}` is the :math:`M \times M` covariance matrix across the inner dimension of the quantity which is estimated,
    :math:`\mathbf{\Pi}_n` the :math:`M \times K_n` matrix that spreads :math:`K_n`
    values to a vector of size :math:`M` according to the ``pilot_mask`` for the :math:`n^{\text{th}}` row,
    and :math:`\tilde{\mathbf{\Sigma}}_n \in \mathbb{R}^{K_n \times K_n}` is the regularized channel estimation error covariance.
    The :math:`i^{\text{th}}`` diagonal element of :math:`\tilde{\mathbf{\Sigma}}_n` is such that:

    .. math::

        \left[ \tilde{\mathbf{\Sigma}}_n \right]_{i,i} = \text{max} \left\{ \right\}

    built from ``err_var`` and assumed to be diagonal.

    The returned channel estimates are

    .. math::
        \begin{bmatrix}
            {\tilde{\mathbf{h}}_1}^\intercal\\
            \vdots\\
            {\tilde{\mathbf{h}}_N}^\intercal
        \end{bmatrix}.

    The returned channel estimation error variances are the diaginal coefficients of

    .. math::
        \text{diag} \left( \mathbf{R} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R} \right), 1 \leq n \leq N

    where :math:`\mathbf{\Xi}_n` is the diagonal matrix of size :math:`M \times M` that zeros the
    columns corresponding to rows not carrying any pilots.
    Note that interpolation is not performed for rows not carrying any pilots.

    **Remark**: The interpolation matrix differs across rows as different
    rows may carry pilots on different elements and/or have different
    estimation error variances.

    Parameters
    ----------
    pilot_mask : [:math:`N`, :math:`M`] : int
        Mask indicating the allocation of resource elements.
        0 : Data,
        1 : Pilot,
        2 : Not used,

    cov_mat : [:math:`M`, :math:`M`], tf.complex
        Covariance matrix of the channel across the inner dimension.

    last_step : bool
        Set to `True` if this is the last interpolation step.
        Otherwise, set to `False`.
        If `True`, the the output is scaled to ensure its variance is as expected
        by the following interpolation step.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates.

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimation error variances.

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates interpolated across the inner dimension.

    err_var : Same shape as ``h_hat``, tf.float
        The channel estimation error variances of the interpolated channel estimates.
    """

    def __init__(self, pilot_mask, cov_mat, last_step):

        self._cdtype = cov_mat.dtype
        assert self._cdtype in (tf.complex64, tf.complex128),\
            "`cov_mat` dtype must be one of tf.complex64 or tf.complex128"
        self._rdtype = self._cdtype.real_dtype
        self._rzero = tf.constant(0.0, self._rdtype)

        # Interpolation is performed along the inner dimension of
        # the resource grid, which may be either the subcarriers
        # or the OFDM symbols dimension.
        # This dimension is referred to as the inner dimension.
        # The other dimension of the resource grid is referred to
        # as the outer dimension.

        # Size of the inner dimension.
        inner_dim_size = tf.shape(pilot_mask)[-1]
        self._inner_dim_size = inner_dim_size

        # Size of the outer dimension.
        outer_dim_size = tf.shape(pilot_mask)[-2]
        self._outer_dim_size = outer_dim_size

        self._cov_mat = cov_mat
        self._last_step = last_step

        # Computation of the interpolation matrix is done solving the
        # least-square problem:
        #
        # X = min_Z |AZ - B|_F^2
        #
        # where A = (\Pi_T R \Pi + S) and
        #       B = R \Pi
        # where R is the channel covariance matrix, S the error
        # diagonal covariance matrix, and \Pi the matrix that spreads the pilots
        # according to the pilot pattern along the inner axis.

        # Extracting the locations of pilots from the pilot mask
        num_tx = tf.shape(pilot_mask)[0]
        num_streams_per_tx = tf.shape(pilot_mask)[1]

        # List of indices of pilots in the inner dimension for every
        # transmit antenna, stream, and outer dimension element.
        pilot_indices = []
        # Maximum number of pilots carried by an inner dimension.
        max_num_pil = 0
        # Indices used to add the error variance to the diagonal
        # elements of the covariance matrix restricted
        # to the elements carrying pilots.
        # These matrices are computed below.
        add_err_var_indices = np.zeros([num_tx, num_streams_per_tx,
                                        outer_dim_size, inner_dim_size, 5], int)
        for tx in range(num_tx):
            pilot_indices.append([])
            for st in range(num_streams_per_tx):
                pilot_indices[-1].append([])
                for oi in range(outer_dim_size):
                    pilot_indices[-1][-1].append([])
                    num_pil = 0 # Number of pilots on this outer dim
                    for ii in range(inner_dim_size):
                        # Check if this RE is carrying a pilot
                        # for this stream
                        if pilot_mask[tx,st,oi,ii] == 0:
                            continue
                        if pilot_mask[tx,st,oi,ii] == 1:
                            pilot_indices[tx][st][oi].append(ii)
                            indices = [tx, st, oi, num_pil, num_pil]
                            add_err_var_indices[tx, st, oi, ii] = indices
                            num_pil += 1
                    max_num_pil = max(max_num_pil, num_pil)
        # [num_tx, num_streams_per_tx, outer_dim_size, inner_dim_size, 5]
        self._add_err_var_indices = tf.cast(add_err_var_indices, tf.int32)

        # Different subcarriers/symbols may carry a different number of pilots.
        # To handle such cases, we create a tensor of square matrices of
        # size the maximum number of pilots carried by an inner dimension
        # and zero-padding is used to handle axes with less pilots than the
        # maximum value. The obtained structure is:
        #
        # |B 0|
        # |0 0|
        #
        pil_cov_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                                max_num_pil, max_num_pil], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            num_pil = len(pil_ind)
            tmp = np.take(cov_mat, pil_ind, axis=0)
            pil_cov_mat_ = np.take(tmp, pil_ind, axis=1)
            pil_cov_mat[tx,st,oi,:num_pil,:num_pil] = pil_cov_mat_
        # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil]
        self._pil_cov_mat = tf.constant(pil_cov_mat, self._cdtype)

        # Pre-compute the covariance matrix with only the columns corresponding
        # to pilots.
941 b_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size, 942 max_num_pil, inner_dim_size], complex) 943 for tx,st,oi in itertools.product(range(num_tx), 944 range(num_streams_per_tx), 945 range(outer_dim_size)): 946 pil_ind = pilot_indices[tx][st][oi] 947 num_pil = len(pil_ind) 948 b_mat_ = np.take(cov_mat, pil_ind, axis=0) 949 b_mat[tx,st,oi,:num_pil,:] = b_mat_ 950 self._b_mat = tf.constant(b_mat, self._cdtype) 951 952 # Indices used to fill with zeros the columns of the interpolation 953 # matrix not corresponding to zeros. 954 # The results is a matrix of size inner_dim_size x inner_dim_size 955 # where rows and columns not correspondong to pilots are set to zero. 956 pil_loc = np.zeros([num_tx, num_streams_per_tx, outer_dim_size, 957 inner_dim_size, max_num_pil, 5], dtype=int) 958 for tx,st,oi,p,ii in itertools.product(range(num_tx), 959 range(num_streams_per_tx), 960 range(outer_dim_size), 961 range(max_num_pil), 962 range(inner_dim_size)): 963 if p >= len(pilot_indices[tx][st][oi]): 964 # An extra dummy subcarrier is added to push there padding 965 # identity matrix 966 pil_loc[tx, st, oi, ii, p] = [tx, st, oi, 967 inner_dim_size, 968 inner_dim_size] 969 else: 970 pil_loc[tx, st, oi, ii, p] = [tx, st, oi, 971 ii, 972 pilot_indices[tx][st][oi][p]] 973 self._pil_loc = tf.cast(pil_loc, tf.int32) 974 975 # Covariance matrix for each stream with only the row corresponding 976 # to a pilot carrying RE not set to 0. 977 # This is required to compute the estimation error variances. 
978 err_var_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size, 979 inner_dim_size, inner_dim_size], complex) 980 for tx,st,oi in itertools.product(range(num_tx), 981 range(num_streams_per_tx), 982 range(outer_dim_size)): 983 pil_ind = pilot_indices[tx][st][oi] 984 mask = np.zeros([inner_dim_size], complex) 985 mask[pil_ind] = 1.0 986 mask = np.expand_dims(mask, axis=1) 987 err_var_mat[tx,st,oi] = cov_mat*mask 988 self._err_var_mat = tf.constant(err_var_mat, self._cdtype) 989 990 def __call__(self, h_hat, err_var): 991 992 # h_hat : [batch_size, num_rx, num_rx_ant, num_tx, 993 # num_streams_per_tx, outer_dim_size, inner_dim_size] 994 # err_var : [batch_size, num_rx, num_rx_ant, num_tx, 995 # num_streams_per_tx, outer_dim_size, inner_dim_size] 996 997 batch_size = tf.shape(h_hat)[0] 998 num_rx = tf.shape(h_hat)[1] 999 num_rx_ant = tf.shape(h_hat)[2] 1000 num_tx = tf.shape(h_hat)[3] 1001 num_tx_stream = tf.shape(h_hat)[4] 1002 outer_dim_size = self._outer_dim_size 1003 inner_dim_size = self._inner_dim_size 1004 1005 ##################################### 1006 # Compute the interpolation matrix 1007 ##################################### 1008 1009 # Computation of the interpolation matrix is done solving the 1010 # least-square problem: 1011 # 1012 # X = min_Z |AZ - B|_F^2 1013 # 1014 # where A = (\Pi_T R \Pi + S) and 1015 # B = R \Pi 1016 # where R is the channel covariance matrix, S the error 1017 # diagonal covariance matrix, and \Pi the matrix that spreads the pilots 1018 # according to the pilot pattern along the inner axis. 
1019 1020 # 1021 # Computing A 1022 # 1023 1024 # Covariance matrices restricted to pilot locations 1025 # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil] 1026 pil_cov_mat = self._pil_cov_mat 1027 1028 # Adding batch, receive, and receive antennas dimensions to the 1029 # covariance matrices restricted to pilot locations and to the 1030 # regularization values 1031 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1032 # outer_dim_size, max_num_pil, max_num_pil] 1033 pil_cov_mat = expand_to_rank(pil_cov_mat, 8, 0) 1034 pil_cov_mat = tf.tile(pil_cov_mat, [batch_size, num_rx, num_rx_ant, 1035 1, 1, 1, 1, 1]) 1036 1037 # Adding the noise variance to the covariance matrices restricted to 1038 # pilots 1039 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1040 # outer_dim_size, max_num_pil, max_num_pil] 1041 pil_cov_mat_ = tf.transpose(pil_cov_mat, [3, 4, 5, 6, 7, 0, 1, 2]) 1042 err_var_ = tf.complex(err_var, self._rzero) 1043 err_var_ = tf.transpose(err_var_, [3, 4, 5, 6, 0, 1, 2]) 1044 a_mat = tf.tensor_scatter_nd_add(pil_cov_mat_, 1045 self._add_err_var_indices, err_var_) 1046 a_mat = tf.transpose(a_mat, [5, 6, 7, 0, 1, 2, 3, 4]) 1047 1048 # 1049 # Computing B 1050 # 1051 1052 # B is pre-computed as it only depend on the channel covariance and 1053 # pilot pattern. 1054 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1055 # outer_dim_size, max_num_pil, inner_dim_size] 1056 b_mat = self._b_mat 1057 b_mat = expand_to_rank(b_mat, 8, 0) 1058 b_mat = tf.tile(b_mat, [batch_size, num_rx, num_rx_ant, 1059 1, 1, 1, 1, 1]) 1060 1061 # 1062 # Computing the interpolation matrix 1063 # 1064 1065 # Using lstsq to compute the columns of the interpolation matrix 1066 # corresponding to pilots. 
1067 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1068 # outer_dim_size, inner_dim_size, max_num_pil] 1069 ext_mat = tf.linalg.lstsq(a_mat, b_mat, fast=False) 1070 ext_mat = tf.transpose(ext_mat, [0,1,2,3,4,5,7,6], conjugate=True) 1071 1072 # Filling with zeros the columns not corresponding to pilots. 1073 # An extra dummy outer dim is added to scatter there the coefficients 1074 # of the identity matrix used for padding. 1075 # This dummy dim is then removed. 1076 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1077 # outer_dim_size, inner_dim_size, inner_dim_size] 1078 ext_mat = tf.transpose(ext_mat, [3, 4, 5, 6, 7, 0, 1, 2]) 1079 ext_mat = tf.scatter_nd(self._pil_loc, ext_mat, 1080 [num_tx, num_tx_stream, 1081 outer_dim_size, 1082 inner_dim_size+1, 1083 inner_dim_size+1, 1084 batch_size, num_rx, num_rx_ant]) 1085 ext_mat = tf.transpose(ext_mat, [5, 6, 7, 0, 1, 2, 3, 4]) 1086 ext_mat = ext_mat[...,:inner_dim_size,:inner_dim_size] 1087 1088 ################################################ 1089 # Apply interpolation over the inner dimension 1090 ################################################ 1091 1092 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1093 # outer_dim_size, inner_dim_size] 1094 h_hat = tf.expand_dims(h_hat, axis=-1) 1095 h_hat = tf.matmul(ext_mat, h_hat) 1096 h_hat = tf.squeeze(h_hat, axis=-1) 1097 1098 ############################## 1099 # Compute the error variances 1100 ############################## 1101 1102 # Keep track of the previous estimation error variances for later use 1103 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1104 # outer_dim_size, inner_dim_size] 1105 err_var_old = err_var 1106 1107 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1108 # outer_dim_size, inner_dim_size] 1109 cov_mat = expand_to_rank(self._cov_mat, 8, 0) 1110 err_var = tf.linalg.diag_part(cov_mat) 1111 err_var_mat = expand_to_rank(self._err_var_mat, 8, 0) 1112 err_var_mat = 
tf.transpose(err_var_mat, [0, 1, 2, 3, 4, 5, 7, 6]) 1113 err_var = err_var - tf.reduce_sum(ext_mat*err_var_mat, axis=-1) 1114 err_var = tf.math.real(err_var) 1115 err_var = tf.maximum(err_var, self._rzero) 1116 1117 ##################################### 1118 # If this is *not* the last 1119 # interpolation step, scales the 1120 # input `h_hat` to ensure 1121 # it has the variance expected by the 1122 # next interpolation step. 1123 # 1124 # The error variance also `err_var` 1125 # is updated accordingly. 1126 ##################################### 1127 if not self._last_step: 1128 # 1129 # Variance of h_hat 1130 # 1131 # Conjugate transpose of LMMSE matrix 1132 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1133 # outer_dim_size, inner_dim_size, inner_dim_size] 1134 ext_mat_h = tf.transpose(ext_mat, [0, 1, 2, 3, 4, 5, 7, 6], 1135 conjugate=True) 1136 # First part of the estimate covariance 1137 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1138 # outer_dim_size, inner_dim_size, inner_dim_size] 1139 h_hat_var_1 = tf.matmul(cov_mat, ext_mat_h) 1140 h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6]) 1141 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1142 # outer_dim_size, inner_dim_size] 1143 h_hat_var_1 = tf.reduce_sum(ext_mat*h_hat_var_1, axis=-1) 1144 # Second part of the estimate covariance 1145 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1146 # outer_dim_size, inner_dim_size] 1147 err_var_old_c = tf.complex(err_var_old, self._rzero) 1148 err_var_old_c = tf.expand_dims(err_var_old_c, axis=-1) 1149 h_hat_var_2 = err_var_old_c*ext_mat_h 1150 h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6]) 1151 h_hat_var_2 = tf.reduce_sum(ext_mat*h_hat_var_2, axis=-1) 1152 # Variance of h_hat 1153 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1154 # outer_dim_size, inner_dim_size] 1155 h_hat_var = h_hat_var_1 + h_hat_var_2 1156 # Scaling factor 1157 # [batch_size, 
class SpatialChannelFilter:
    # pylint: disable=line-too-long
    r"""SpatialChannelFilter(cov_mat, last_step)

    Linear minimum mean square error (LMMSE) smoothing across the receive
    antennas.

    The underlying model is

    .. math::

        \mathbf{y} = \mathbf{h} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the received signal vector,
    :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated,
    with covariance matrix
    :math:`\mathbb{E}\left[ \mathbf{h} \mathbf{h}^{\mathsf{H}} \right] = \mathbf{R}`,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean noise vector whose
    elements have variance :math:`N_0`.

    The smoothed estimate is

    .. math::

        \hat{\mathbf{h}} = \mathbf{A} \mathbf{y}, \qquad
        \mathbf{A} = \mathbf{R} \left( \mathbf{R} + N_0 \mathbf{I}_M \right)^{-1}

    where :math:`\mathbf{I}_M` is the :math:`M \times M` identity matrix.
    In the implementation, the input ``err_var`` is added to the diagonal of
    :math:`\mathbf{R}`, so the noise variance may differ across antennas.

    With the estimation error
    :math:`\tilde{\mathbf{h}} = \mathbf{h} - \hat{\mathbf{h}}`, the returned
    error variances

    .. math::

        \sigma^2_i = \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right], 0 \leq i \leq M-1

    are the diagonal elements of

    .. math::

        \mathbb{E}\left[\mathbf{\tilde{h}} \mathbf{\tilde{h}}^{\mathsf{H}} \right] = \mathbf{R} - \mathbf{A}\mathbf{R}.

    Note
    ----
    If you want to use this function in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.

    Parameters
    ----------
    cov_mat : [num_rx_ant, num_rx_ant], tf.complex
        Spatial covariance matrix of the channel

    last_step : bool
        Set to `True` if this is the last interpolation step.
        Otherwise, set to `False`.
        If `False`, the output is scaled to ensure its variance is as
        expected by the following interpolation step.

    Input
    -----
    h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
        Channel estimates.

    err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
        Channel estimation error variances.

    Output
    ------
    h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
        Channel estimates smoothed across the spatial dimension

    err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
        The channel estimation error variances of the smoothed channel estimates.
    """

    def __init__(self, cov_mat, last_step):
        self._cov_mat = cov_mat
        self._last_step = last_step
        # Real-valued zero with the real dtype matching `cov_mat`
        self._rzero = tf.zeros((), cov_mat.dtype.real_dtype)

        # Indices [k, k] of the diagonal elements. They are used to add a
        # tensor of vectors [..., num_rx_ant] to the diagonal of a tensor
        # of matrices [..., num_rx_ant, num_rx_ant].
        diag_indices = [[k]*2 for k in range(cov_mat.shape[0])]
        self._add_diag_indices = tf.cast(diag_indices, tf.int32)

    def __call__(self, h_hat, err_var):
        # h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx,
        #           num_ofdm_symbols, num_subcarriers, num_rx_ant]
        # err_var : same shape as h_hat

        # Permutation swapping the last two axes of a rank-8 tensor
        swap_last_two = [0, 1, 2, 3, 4, 5, 7, 6]

        # Complex-valued input error variances; also needed for the
        # re-scaling at the end.
        # [..., num_rx_ant]
        noise_c = tf.complex(err_var, self._rzero)

        # [num_rx_ant, num_rx_ant]
        cov_2d = self._cov_mat
        cov_2d_t = tf.transpose(cov_2d)
        num_rx_ant = tf.shape(cov_2d)[0]

        ##########################################
        # Compute LMMSE matrix
        ##########################################

        # [..., num_rx_ant, num_rx_ant] with singleton leading dims
        cov_mat = expand_to_rank(cov_2d, tf.rank(noise_c)+1, axis=0)

        # R + diag(err_var). The two matrix axes are moved first for the
        # scatter so that the [k, k] indices address them, then moved back.
        full_shape = tf.concat([tf.shape(noise_c), [num_rx_ant]], axis=0)
        # [num_rx_ant, num_rx_ant, ...]
        regularized = tf.transpose(tf.broadcast_to(cov_mat, full_shape),
                                   [6, 7, 0, 1, 2, 3, 4, 5])
        regularized = tf.tensor_scatter_nd_add(
            regularized,
            self._add_diag_indices,
            # [num_rx_ant, ...]
            tf.transpose(noise_c, [6, 0, 1, 2, 3, 4, 5]))
        # [..., num_rx_ant, num_rx_ant]
        regularized = tf.transpose(regularized, [2, 3, 4, 5, 6, 7, 0, 1])

        # A = R (R + diag(err_var))^-1
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = tf.matmul(cov_mat, matrix_inv(regularized))

        ##########################################
        # Apply smoothing
        ##########################################

        # [..., num_rx_ant]
        h_hat = tf.squeeze(tf.matmul(lmmse_mat,
                                     tf.expand_dims(h_hat, axis=-1)),
                           axis=-1)

        ##########################################
        # Compute the estimation error variances
        ##########################################

        # diag(R) - diag(A R), clipped at zero
        # [..., num_rx_ant, num_rx_ant]
        cov_mat_t = expand_to_rank(cov_2d_t, tf.rank(lmmse_mat), axis=0)
        # [..., num_rx_ant]
        h_var = tf.linalg.diag_part(cov_mat)
        err_var = h_var - tf.reduce_sum(cov_mat_t*lmmse_mat, axis=-1)
        err_var = tf.maximum(tf.math.real(err_var), self._rzero)

        # No re-scaling after the last step
        if self._last_step:
            return h_hat, err_var

        ##########################################
        # Not the last step: scale `h_hat` to the
        # variance expected by the next step and
        # update `err_var` accordingly.
        ##########################################

        # Conjugate transpose of the LMMSE matrix
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat_h = tf.transpose(lmmse_mat, swap_last_two, conjugate=True)

        # diag(A R A^H)
        # [..., num_rx_ant]
        var_from_channel = tf.transpose(tf.matmul(cov_mat, lmmse_mat_h),
                                        swap_last_two)
        var_from_channel = tf.reduce_sum(lmmse_mat*var_from_channel, axis=-1)

        # diag(A S A^H), with S the diagonal matrix of input variances
        # [..., num_rx_ant]
        var_from_noise = tf.transpose(
            tf.expand_dims(noise_c, axis=-1)*lmmse_mat_h, swap_last_two)
        var_from_noise = tf.reduce_sum(lmmse_mat*var_from_noise, axis=-1)

        # Variance of the smoothed estimate
        # [..., num_rx_ant]
        h_hat_var = var_from_channel + var_from_noise

        # Scaling factor; divide_no_nan yields 0 where the denominator is 0
        # [..., num_rx_ant]
        err_var_c = tf.complex(err_var, self._rzero)
        s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c)

        # Scaled estimate and matching error variance
        # [..., num_rx_ant]
        h_hat = s*h_hat
        err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c
        err_var = tf.maximum(tf.math.real(err_var), self._rzero)

        return h_hat, err_var
1343 ########################################## 1344 if not self._last_step: 1345 # 1346 # Variance of h_hat 1347 # 1348 # Conjugate transpose of the LMMSE matrix 1349 # [..., num_rx_ant, num_rx_ant] 1350 lmmse_mat_h = tf.transpose(lmmse_mat, [0, 1, 2, 3, 4, 5, 7, 6], 1351 conjugate=True) 1352 # First part of the estimate covariance 1353 # [..., num_rx_ant, num_rx_ant] 1354 h_hat_var_1 = tf.matmul(cov_mat, lmmse_mat_h) 1355 h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6]) 1356 # [..., num_rx_ant] 1357 h_hat_var_1 = tf.reduce_sum(lmmse_mat*h_hat_var_1, axis=-1) 1358 # Second part of the estimate covariance 1359 # [..., num_rx_ant, 1] 1360 err_var_old = tf.expand_dims(err_var_old, axis=-1) 1361 # [..., num_rx_ant, num_rx_ant] 1362 h_hat_var_2 = err_var_old*lmmse_mat_h 1363 # [..., num_rx_ant, num_rx_ant] 1364 h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6]) 1365 # [..., num_rx_ant] 1366 h_hat_var_2 = tf.reduce_sum(lmmse_mat*h_hat_var_2, axis=-1) 1367 # Variance of h_hat 1368 # [..., num_rx_ant] 1369 h_hat_var = h_hat_var_1 + h_hat_var_2 1370 # Scaling factor 1371 # [..., num_rx_ant] 1372 err_var_c = tf.complex(err_var, self._rzero) 1373 h_var = tf.linalg.diag_part(cov_mat) 1374 s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c) 1375 # Apply scaling to estimate 1376 # [..., num_rx_ant] 1377 h_hat = s*h_hat 1378 # Updated variance 1379 # [..., num_rx_ant] 1380 err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c 1381 err_var = tf.math.real(err_var) 1382 err_var = tf.maximum(err_var, self._rzero) 1383 1384 return h_hat, err_var 1385 1386 1387 class LMMSEInterpolator(BaseChannelInterpolator): 1388 # pylint: disable=line-too-long 1389 r"""LMMSEInterpolator(pilot_pattern, cov_mat_time, cov_mat_freq, cov_mat_space=None, order='t-f') 1390 1391 LMMSE interpolation on a resource grid with optional spatial smoothing. 
1392 1393 This class computes for each element of an OFDM resource grid 1394 a channel estimate and error variance 1395 through linear minimum mean square error (LMMSE) interpolation/smoothing. 1396 It is assumed that the measurements were taken at the nonzero positions 1397 of a :class:`~sionna.ofdm.PilotPattern`. 1398 1399 Depending on the value of ``order``, the interpolation is carried out 1400 accross time (t), i.e., OFDM symbols, frequency (f), i.e., subcarriers, 1401 and optionally space (s), i.e., receive antennas, in any desired order. 1402 1403 For simplicity, we describe the underlying algorithm assuming that interpolation 1404 across the sub-carriers is performed first, followed by interpolation across 1405 OFDM symbols, and finally by spatial smoothing across receive 1406 antennas. 1407 The algorithm is similar if interpolation and/or smoothing are performed in 1408 a different order. 1409 For clarity, antenna indices are omitted when describing frequency and time 1410 interpolation, as the same process is applied to all the antennas. 1411 1412 The input ``h_hat`` is first reshaped to a resource grid 1413 :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`, by scattering the channel 1414 estimates at pilot locations according to the ``pilot_pattern``. :math:`N` 1415 denotes the number of OFDM symbols and :math:`M` the number of sub-carriers. 1416 1417 The first pass consists in interpolating across the sub-carriers: 1418 1419 .. math:: 1420 \hat{\mathbf{h}}_n^{(1)} = \mathbf{A}_n \hat{\mathbf{h}}_n 1421 1422 where :math:`1 \leq n \leq N` is the OFDM symbol index and :math:`\hat{\mathbf{h}}_n` is 1423 the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`. 1424 :math:`\mathbf{A}_n` is the :math:`M \times M` matrix such that: 1425 1426 .. math:: 1427 \mathbf{A}_n = \bar{\mathbf{A}}_n \mathbf{\Pi}_n^\intercal 1428 1429 where 1430 1431 .. 
math:: 1432 \bar{\mathbf{A}}_n = \underset{\mathbf{Z} \in \mathbb{C}^{M \times K_n}}{\text{argmin}} \left\lVert \mathbf{Z}\left( \mathbf{\Pi}_n^\intercal \mathbf{R^{(f)}} \mathbf{\Pi}_n + \mathbf{\Sigma}_n \right) - \mathbf{R^{(f)}} \mathbf{\Pi}_n \right\rVert_{\text{F}}^2 1433 1434 and :math:`\mathbf{R^{(f)}}` is the :math:`M \times M` channel frequency covariance matrix, 1435 :math:`\mathbf{\Pi}_n` the :math:`M \times K_n` matrix that spreads :math:`K_n` 1436 values to a vector of size :math:`M` according to the ``pilot_pattern`` for the :math:`n^{\text{th}}` OFDM symbol, 1437 and :math:`\mathbf{\Sigma}_n \in \mathbb{R}^{K_n \times K_n}` is the channel estimation error covariance built from 1438 ``err_var`` and assumed to be diagonal. 1439 Computation of :math:`\bar{\mathbf{A}}_n` is done using an algorithm based on complete orthogonal decomposition. 1440 This is done to avoid matrix inversion for badly conditioned covariance matrices. 1441 1442 The channel estimation error variances after the first interpolation pass are computed as 1443 1444 .. math:: 1445 \mathbf{\Sigma}^{(1)}_n = \text{diag} \left( \mathbf{R^{(f)}} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R^{(f)}} \right) 1446 1447 where :math:`\mathbf{\Xi}_n` is the diagonal matrix of size :math:`M \times M` that zeros the 1448 columns corresponding to sub-carriers not carrying any pilots. 1449 Note that interpolation is not performed for OFDM symbols which do not carry pilots. 1450 1451 **Remark**: The interpolation matrix differs across OFDM symbols as different 1452 OFDM symbols may carry pilots on different sub-carriers and/or have different 1453 estimation error variances. 1454 1455 Scaling of the estimates is then performed to ensure that their 1456 variances match the ones expected by the next interpolation step, and the error variances are updated accordingly: 1457 1458 .. 
math:: 1459 \begin{align} 1460 \left[\hat{\mathbf{h}}_n^{(2)}\right]_m &= s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m\\ 1461 \left[\mathbf{\Sigma}^{(2)}_n\right]_{m,m} &= s_{n,m}\left( s_{n,m}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m} + \left( 1 - s_{n,m} \right) \left[\mathbf{R^{(f)}}\right]_{m,m} + s_{n,m} \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m} 1462 \end{align} 1463 1464 where the scaling factor :math:`s_{n,m}` is such that: 1465 1466 1467 .. math:: 1468 \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m \right\rvert^2 \right\} = \left[\mathbf{R^{(f)}}\right]_{m,m} + \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}^{(1)}_n\right]_m - \left[\mathbf{h}_n\right]_m \right\rvert^2 \right\} 1469 1470 which leads to: 1471 1472 .. math:: 1473 \begin{align} 1474 s_{n,m} &= \frac{2 \left[\mathbf{R^{(f)}}\right]_{m,m}}{\left[\mathbf{R^{(f)}}\right]_{m,m} - \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m} + \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m}}\\ 1475 \hat{\mathbf{\Sigma}}^{(1)}_n &= \mathbf{A}_n \mathbf{R^{(f)}} \mathbf{A}_n^{\mathrm{H}}. 1476 \end{align} 1477 1478 The second pass consists in interpolating across the OFDM symbols: 1479 1480 .. math:: 1481 \hat{\mathbf{h}}_m^{(3)} = \mathbf{B}_m \tilde{\mathbf{h}}^{(2)}_m 1482 1483 where :math:`1 \leq m \leq M` is the sub-carrier index and :math:`\tilde{\mathbf{h}}^{(2)}_m` is 1484 the :math:`m^{\text{th}}` column of 1485 1486 .. math:: 1487 \hat{\mathbf{H}}^{(2)} = \begin{bmatrix} 1488 {\hat{\mathbf{h}}_1^{(2)}}^\intercal\\ 1489 \vdots\\ 1490 {\hat{\mathbf{h}}_N^{(2)}}^\intercal 1491 \end{bmatrix} 1492 1493 and :math:`\mathbf{B}_m` is the :math:`N \times N` interpolation LMMSE matrix: 1494 1495 .. math:: 1496 \mathbf{B}_m = \bar{\mathbf{B}}_m \tilde{\mathbf{\Pi}}_m^\intercal 1497 1498 where 1499 1500 .. 
math:: 1501 \bar{\mathbf{B}}_m = \underset{\mathbf{Z} \in \mathbb{C}^{N \times L_m}}{\text{argmin}} \left\lVert \mathbf{Z} \left( \tilde{\mathbf{\Pi}}_m^\intercal \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m + \tilde{\mathbf{\Sigma}}^{(2)}_m \right) - \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m \right\rVert_{\text{F}}^2 1502 1503 where :math:`\mathbf{R^{(t)}}` is the :math:`N \times N` channel time covariance matrix, 1504 :math:`\tilde{\mathbf{\Pi}}_m` the :math:`N \times L_m` matrix that spreads :math:`L_m` 1505 values to a vector of size :math:`N` according to the ``pilot_pattern`` for the :math:`m^{\text{th}}` sub-carrier, 1506 and :math:`\tilde{\mathbf{\Sigma}}^{(2)}_m \in \mathbb{R}^{L_m \times L_m}` is the diagonal matrix of channel estimation error variances 1507 built by gathering the error variances from (:math:`\mathbf{\Sigma}^{(2)}_1,\dots,\mathbf{\Sigma}^{(2)}_N`) corresponding 1508 to resource elements carried by the :math:`m^{\text{th}}` sub-carrier. 1509 Computation of :math:`\bar{\mathbf{B}}_m` is done using an algorithm based on complete orthogonal decomposition. 1510 This is done to avoid matrix inversion for badly conditioned covariance matrices. 1511 1512 The resulting channel estimate for the resource grid is 1513 1514 .. math:: 1515 \hat{\mathbf{H}}^{(3)} = \left[ \hat{\mathbf{h}}_1^{(3)} \dots \hat{\mathbf{h}}_M^{(3)} \right] 1516 1517 The resulting channel estimation error variances are the diagonal coefficients of the matrices 1518 1519 .. math:: 1520 \mathbf{\Sigma}^{(3)}_m = \mathbf{R^{(t)}} - \mathbf{B}_m \tilde{\mathbf{\Xi}}_m \mathbf{R^{(t)}}, 1 \leq m \leq M 1521 1522 where :math:`\tilde{\mathbf{\Xi}}_m` is the diagonal matrix of size :math:`N \times N` that zeros the 1523 columns corresponding to OFDM symbols not carrying any pilots. 1524 1525 **Remark**: The interpolation matrix differs across sub-carriers as different 1526 sub-carriers may have different estimation error variances computed by the first 1527 pass. 
1528 However, all sub-carriers carry at least one channel estimate as a result of 1529 the first pass, ensuring that a channel estimate is computed for all the resource 1530 elements after the second pass. 1531 1532 **Remark:** LMMSE interpolation requires knowledge of the time and frequency 1533 covariance matrices of the channel. The notebook `OFDM MIMO Channel Estimation and Detection <../examples/OFDM_MIMO_Detection.ipynb>`_ shows how to estimate 1534 such matrices for arbitrary channel models. 1535 Moreover, the functions :func:`~sionna.ofdm.tdl_time_cov_mat` 1536 and :func:`~sionna.ofdm.tdl_freq_cov_mat` compute the expected time and frequency 1537 covariance matrices, respectively, for the :class:`~sionna.channel.tr38901.TDL` channel models. 1538 1539 Scaling of the estimates is then performed to ensure that their 1540 variances match the ones expected by the next smoothing step, and the 1541 error variances are updated accordingly: 1542 1543 .. math:: 1544 \begin{align} 1545 \left[\hat{\mathbf{h}}_m^{(4)}\right]_n &= \gamma_{m,n} \left[\hat{\mathbf{h}}_m^{(3)}\right]_n\\ 1546 \left[\mathbf{\Sigma}^{(4)}_m\right]_{n,n} &= \gamma_{m,n}\left( \gamma_{m,n}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(3)}_m\right]_{n,n} + \left( 1 - \gamma_{m,n} \right) \left[\mathbf{R^{(t)}}\right]_{n,n} + \gamma_{m,n} \left[\mathbf{\Sigma}^{(3)}_n\right]_{m,m} 1547 \end{align} 1548 1549 where: 1550 1551 .. math:: 1552 \begin{align} 1553 \gamma_{m,n} &= \frac{2 \left[\mathbf{R^{(t)}}\right]_{n,n}}{\left[\mathbf{R^{(t)}}\right]_{n,n} - \left[\mathbf{\Sigma}^{(3)}_m\right]_{n,n} + \left[\hat{\mathbf{\Sigma}}^{(3)}_n\right]_{m,m}}\\ 1554 \hat{\mathbf{\Sigma}}^{(3)}_m &= \mathbf{B}_m \mathbf{R^{(t)}} \mathbf{B}_m^{\mathrm{H}} 1555 \end{align} 1556 1557 Finally, a spatial smoothing step is applied to every resource element carrying 1558 a channel estimate. 1559 For clarity, we drop the resource element indexing :math:`(n,m)`. 
We denote by :math:`L` the number of receive antennas, and by 1561 :math:`\mathbf{R^{(s)}}\in\mathbb{C}^{L \times L}` the spatial covariance matrix. 1562 1563 LMMSE spatial smoothing consists in the following computations: 1564 1565 .. math:: 1566 \hat{\mathbf{h}}^{(5)} = \mathbf{C} \hat{\mathbf{h}}^{(4)} 1567 1568 where 1569 1570 .. math:: 1571 \mathbf{C} = \mathbf{R^{(s)}} \left( \mathbf{R^{(s)}} + \mathbf{\Sigma}^{(4)} \right)^{-1}. 1572 1573 The estimation error variances are the diagonal coefficients of 1574 1575 .. math:: 1576 \mathbf{\Sigma}^{(5)} = \mathbf{R^{(s)}} - \mathbf{C}\mathbf{R^{(s)}} 1577 1578 The smoothed channel estimate :math:`\hat{\mathbf{h}}^{(5)}` and corresponding 1579 error variances :math:`\text{diag}\left( \mathbf{\Sigma}^{(5)} \right)` are 1580 returned for every resource element :math:`(n,m)`. 1581 1582 **Remark:** No scaling is performed after the last interpolation or smoothing 1583 step. 1584 1585 **Remark:** All passes assume that the estimation error covariance matrix 1586 (:math:`\mathbf{\Sigma}`, :math:`\tilde{\mathbf{\Sigma}}^{(2)}`, or :math:`\tilde{\mathbf{\Sigma}}^{(4)}`) is diagonal, which 1587 may not be accurate. When this assumption does not hold, this interpolator is only 1588 an approximation of LMMSE interpolation. 1589 1590 **Remark:** The order in which frequency interpolation, temporal 1591 interpolation, and, optionally, spatial smoothing are applied, is controlled using the 1592 ``order`` parameter. 1593 1594 Note 1595 ---- 1596 This layer does not support graph mode with XLA. 
1597 1598 Parameters 1599 ---------- 1600 pilot_pattern : PilotPattern 1601 An instance of :class:`~sionna.ofdm.PilotPattern` 1602 1603 cov_mat_time : [num_ofdm_symbols, num_ofdm_symbols], tf.complex 1604 Time covariance matrix of the channel 1605 1606 cov_mat_freq : [fft_size, fft_size], tf.complex 1607 Frequency covariance matrix of the channel 1608 1609 cov_mat_space : [num_rx_ant, num_rx_ant], tf.complex 1610 Spatial covariance matrix of the channel. 1611 Defaults to `None`. 1612 Only required if spatial smoothing is requested (see ``order``). 1613 1614 order : str 1615 Order in which to perform interpolation and optional smoothing. 1616 For example, ``"t-f-s"`` means that interpolation across the OFDM symbols 1617 is performed first (``"t"``: time), followed by interpolation across the 1618 sub-carriers (``"f"``: frequency), and finally smoothing across the 1619 receive antennas (``"s"``: space). 1620 Similarly, ``"f-t"`` means interpolation across the sub-carriers followed 1621 by interpolation across the OFDM symbols and no spatial smoothing. 1622 The spatial covariance matrix (``cov_mat_space``) is only required when 1623 spatial smoothing is requested. 1624 Time and frequency interpolation are not optional to ensure that a channel 1625 estimate is computed for all resource elements. 
1626 1627 Input 1628 ----- 1629 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex 1630 Channel estimates for the pilot-carrying resource elements 1631 1632 err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex 1633 Channel estimation error variances for the pilot-carrying resource elements 1634 1635 Output 1636 ------ 1637 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex 1638 Channel estimates accross the entire resource grid for all 1639 transmitters and streams 1640 1641 err_var : Same shape as ``h_hat``, tf.float 1642 Channel estimation error variances accross the entire resource grid 1643 for all transmitters and streams 1644 """ 1645 1646 def __init__(self, pilot_pattern, cov_mat_time, cov_mat_freq, 1647 cov_mat_space=None, order='t-f'): 1648 1649 # Check the specified order 1650 order = order.split('-') 1651 assert 2 <= len(order) <= 3, "Invalid order for interpolation." 
1652 spatial_smoothing = False 1653 freq_smoothing = False 1654 time_smoothing = False 1655 for o in order: 1656 assert o in ('s', 'f', 't'), f"Uknown dimension {o}" 1657 if o == 's': 1658 assert not spatial_smoothing,\ 1659 "Spatial smoothing can be specified at most once" 1660 spatial_smoothing = True 1661 elif o == 't': 1662 assert not time_smoothing,\ 1663 "Temporal interpolation can be specified once only" 1664 time_smoothing = True 1665 elif o == 'f': 1666 assert not freq_smoothing,\ 1667 "Frequency interpolation can be specified once only" 1668 freq_smoothing = True 1669 if spatial_smoothing: 1670 assert cov_mat_space is not None,\ 1671 "A spatial covariance matrix is required for spatial smoothing" 1672 assert freq_smoothing, "Frequency interpolation is required" 1673 assert time_smoothing, "Time interpolation is required" 1674 1675 self._order = order 1676 self._num_ofdm_symbols = pilot_pattern.num_ofdm_symbols 1677 self._num_effective_subcarriers =pilot_pattern.num_effective_subcarriers 1678 1679 # Build pilot masks for every stream 1680 pilot_mask = self._build_pilot_mask(pilot_pattern) 1681 1682 # Build indices for mapping channel estimates and 1683 # error variances that are given as input to a 1684 # resource grid 1685 num_pilots = pilot_pattern.pilots.shape[2] 1686 inputs_to_rg_indices = self._build_inputs2rg_indices(pilot_mask, 1687 num_pilots) 1688 self._inputs_to_rg_indices = tf.cast(inputs_to_rg_indices, tf.int32) 1689 1690 # 1D interpolator according to requested order 1691 # Interpolation is always performed along the inner dimension. 1692 interpolators = [] 1693 # Masks for masking error variances that were not updated 1694 err_var_masks = [] 1695 for i, o in enumerate(order): 1696 # Is it the last one? 
1697 last_step = i == len(order)-1 1698 # Frequency 1699 if o == "f": 1700 interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_freq, 1701 last_step=last_step) 1702 pilot_mask = self._update_pilot_mask_interp(pilot_mask) 1703 err_var_mask = tf.cast(pilot_mask == 1, 1704 cov_mat_freq.dtype.real_dtype) 1705 # Time 1706 elif o == 't': 1707 pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2]) 1708 interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_time, 1709 last_step=last_step) 1710 pilot_mask = self._update_pilot_mask_interp(pilot_mask) 1711 pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2]) 1712 err_var_mask = tf.cast(pilot_mask == 1, 1713 cov_mat_freq.dtype.real_dtype) 1714 # Space 1715 else: 1716 interpolator = SpatialChannelFilter(cov_mat_space, 1717 last_step=last_step) 1718 err_var_mask = tf.cast(pilot_mask == 1, 1719 cov_mat_freq.dtype.real_dtype) 1720 interpolators.append(interpolator) 1721 err_var_masks.append(err_var_mask) 1722 self._interpolators = interpolators 1723 self._err_var_masks = err_var_masks 1724 1725 def _build_pilot_mask(self, pilot_pattern): 1726 """ 1727 Build for every transmitter and stream a pilot mask indicating 1728 which REs are allocated to pilots, data, or not used. 
1729 # 0 -> Data 1730 # 1 -> Pilot 1731 # 2 -> Not used 1732 """ 1733 1734 mask = pilot_pattern.mask 1735 pilots = pilot_pattern.pilots 1736 num_tx = mask.shape[0] 1737 num_streams_per_tx = mask.shape[1] 1738 num_ofdm_symbols = mask.shape[2] 1739 num_effective_subcarriers = mask.shape[3] 1740 1741 pilot_mask = np.zeros([num_tx, num_streams_per_tx, num_ofdm_symbols, 1742 num_effective_subcarriers], int) 1743 for tx,st in itertools.product( range(num_tx), 1744 range(num_streams_per_tx)): 1745 pil_index = 0 1746 for sb,sc in itertools.product( range(num_ofdm_symbols), 1747 range(num_effective_subcarriers)): 1748 if mask[tx,st,sb,sc] == 1: 1749 if np.abs(pilots[tx,st,pil_index]) > 0.0: 1750 pilot_mask[tx,st,sb,sc] = 1 1751 else: 1752 pilot_mask[tx,st,sb,sc] = 2 1753 pil_index += 1 1754 1755 return pilot_mask 1756 1757 def _build_inputs2rg_indices(self, pilot_mask, num_pilots): 1758 """ 1759 Builds indices for mapping channel estimates and 1760 error variances that are given as input to a 1761 resource grid 1762 """ 1763 1764 num_tx = pilot_mask.shape[0] 1765 num_streams_per_tx = pilot_mask.shape[1] 1766 num_ofdm_symbols = pilot_mask.shape[2] 1767 num_effective_subcarriers = pilot_mask.shape[3] 1768 1769 inputs_to_rg_indices = np.zeros([num_tx, num_streams_per_tx, 1770 num_pilots, 4], int) 1771 for tx,st in itertools.product( range(num_tx), 1772 range(num_streams_per_tx)): 1773 pil_index = 0 # Pilot index for this stream 1774 for sb,sc in itertools.product( range(num_ofdm_symbols), 1775 range(num_effective_subcarriers)): 1776 if pilot_mask[tx,st,sb,sc] == 0: 1777 continue 1778 if pilot_mask[tx,st,sb,sc] == 1: 1779 inputs_to_rg_indices[tx, st, pil_index] = [tx, st, sb, sc] 1780 pil_index += 1 1781 1782 return inputs_to_rg_indices 1783 1784 def _update_pilot_mask_interp(self, pilot_mask): 1785 """ 1786 Update the pilot mask to label the resource elements for which the 1787 channel was interpolated. 
1788 """ 1789 1790 interpolated = np.any(pilot_mask == 1, axis=-1, keepdims=True) 1791 pilot_mask = np.where(interpolated, 1, pilot_mask) 1792 1793 return pilot_mask 1794 1795 def __call__(self, h_hat, err_var): 1796 1797 # h_hat : [batch_size, num_rx, num_rx_ant, num_tx, 1798 # num_streams_per_tx, num_pilots] 1799 # err_var : [batch_size, num_rx, num_rx_ant, num_tx, 1800 # num_streams_per_tx, num_pilots] 1801 1802 batch_size = tf.shape(h_hat)[0] 1803 num_rx = tf.shape(h_hat)[1] 1804 num_rx_ant = tf.shape(h_hat)[2] 1805 num_tx = tf.shape(h_hat)[3] 1806 num_tx_stream = tf.shape(h_hat)[4] 1807 num_ofdm_symbols = self._num_ofdm_symbols 1808 num_effective_subcarriers = self._num_effective_subcarriers 1809 1810 # For some estimator, err_var might not have the same shape 1811 # as h_hat 1812 err_var = tf.broadcast_to(err_var, tf.shape(h_hat)) 1813 1814 # Mapping the channel estimates and error variances to a resource grid 1815 # all : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1816 # num_ofdm_symbols, num_effective_subcarriers] 1817 h_hat = tf.transpose(h_hat, [3, 4, 5, 0, 1, 2]) 1818 err_var = tf.transpose(err_var, [3, 4, 5, 0, 1, 2]) 1819 h_hat = tf.scatter_nd(self._inputs_to_rg_indices, h_hat, 1820 [num_tx, num_tx_stream, 1821 num_ofdm_symbols, 1822 num_effective_subcarriers, 1823 batch_size, num_rx, num_rx_ant]) 1824 err_var = tf.scatter_nd(self._inputs_to_rg_indices, err_var, 1825 [num_tx, num_tx_stream, 1826 num_ofdm_symbols, 1827 num_effective_subcarriers, 1828 batch_size, num_rx, num_rx_ant]) 1829 h_hat = tf.transpose(h_hat, [4, 5, 6, 0, 1, 2, 3]) 1830 err_var = tf.transpose(err_var, [4, 5, 6, 0, 1, 2, 3]) 1831 1832 # Interpolation 1833 # Performed according to the requested order. Transpose are used as 1834 # 1D interpolation is performed along the inner axis. 
1835 items = zip(self._order, self._interpolators, self._err_var_masks) 1836 for o,interp,err_var_mask in items: 1837 # Frequency 1838 if o == 'f': 1839 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1840 # num_ofdm_symbols, num_effective_subcarriers] 1841 h_hat, err_var = interp(h_hat, err_var) 1842 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) 1843 err_var = err_var*err_var_mask 1844 # Time 1845 elif o == 't': 1846 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1847 # num_effective_subcarriers, num_ofdm_symbols] 1848 h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5]) 1849 err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5]) 1850 h_hat, err_var = interp(h_hat, err_var) 1851 # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, 1852 # num_ofdm_symbols, num_effective_subcarriers] 1853 h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5]) 1854 err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5]) 1855 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) 1856 err_var = err_var*err_var_mask 1857 # Space 1858 elif o == 's': 1859 # [batch_size, num_rx, num_tx, num_streams_per_tx, 1860 # num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] 1861 h_hat = tf.transpose(h_hat, [0, 1, 3, 4, 5, 6, 2]) 1862 err_var = tf.transpose(err_var, [0, 1, 3, 4, 5, 6, 2]) 1863 h_hat, err_var = interp(h_hat, err_var) 1864 # [batch_size, num_rx, num_tx, num_streams_per_tx, 1865 # num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] 1866 h_hat = tf.transpose(h_hat, [0, 1, 6, 2, 3, 4, 5]) 1867 err_var = tf.transpose(err_var, [0, 1, 6, 2, 3, 4, 5]) 1868 err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) 1869 err_var = err_var*err_var_mask 1870 1871 return h_hat, err_var 1872 1873 ####################################################### 1874 # Utilities 1875 ####################################################### 1876 1877 def tdl_freq_cov_mat(model, subcarrier_spacing, fft_size, delay_spread, 1878 
dtype=tf.complex64): 1879 # pylint: disable=line-too-long 1880 r""" 1881 Computes the frequency covariance matrix of a 1882 :class:`~sionna.channel.tr38901.TDL` channel model. 1883 1884 The channel frequency covariance matrix :math:`\mathbf{R}^{(f)}` of a TDL channel model is 1885 1886 .. math:: 1887 \mathbf{R}^{(f)}_{u,v} = \sum_{\ell=1}^L P_\ell e^{-j 2 \pi \tau_\ell \Delta_f (u-v)}, 1 \leq u,v \leq M 1888 1889 where :math:`M` is the FFT size, :math:`L` is the number of paths for the selected TDL model, 1890 :math:`P_\ell` and :math:`\tau_\ell` are the average power and delay for the 1891 :math:`\ell^{\text{th}}` path, respectively, and :math:`\Delta_f` is the sub-carrier spacing. 1892 1893 Input 1894 ------ 1895 model : str 1896 TDL model for which to return the covariance matrix. 1897 Should be one of "A", "B", "C", "D", or "E". 1898 1899 subcarrier_spacing : float 1900 Sub-carrier spacing [Hz] 1901 1902 fft_size : float 1903 FFT size 1904 1905 delay_spread : float 1906 Delay spread [s] 1907 1908 dtype : tf.DType 1909 Datatype to use for the output. 1910 Should be one of `tf.complex64` or `tf.complex128`. 1911 Defaults to `tf.complex64`. 
1912 1913 Output 1914 ------ 1915 cov_mat : [fft_size, fft_size], tf.complex 1916 Channel frequency covariance matrix 1917 """ 1918 1919 assert dtype in (tf.complex64, tf.complex128),\ 1920 "The `dtype` should be a complex datatype" 1921 1922 # 1923 # Load the power delay profile 1924 # 1925 1926 # Set the file from which to load the model 1927 assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model" 1928 if model == 'A': 1929 parameters_fname = "TDL-A.json" 1930 elif model == 'B': 1931 parameters_fname = "TDL-B.json" 1932 elif model == 'C': 1933 parameters_fname = "TDL-C.json" 1934 elif model == 'D': 1935 parameters_fname = "TDL-D.json" 1936 else: # 'E' 1937 parameters_fname = "TDL-E.json" 1938 source = files(models).joinpath(parameters_fname) 1939 # pylint: disable=unspecified-encoding 1940 with open(source) as parameter_file: 1941 params = json.load(parameter_file) 1942 # LoS scenario ? 1943 los = bool(params['los']) 1944 # Retrieve power and delays 1945 delays = np.array(params['delays'])*delay_spread 1946 mean_powers = np.power(10.0, np.array(params['powers'])/10.0) 1947 1948 if los: 1949 # Add the power of the specular and non-specular component of 1950 # the first path 1951 mean_powers[0] = mean_powers[0] + mean_powers[1] 1952 mean_powers = np.concatenate([mean_powers[:1], mean_powers[2:]], axis=0) 1953 # The first two paths have 0 delays as they correspond to the 1954 # specular and reflected components of the first path. 
1955 delays = delays[1:] 1956 1957 # Normalize the PDP 1958 norm_factor = np.sum(mean_powers) 1959 mean_powers = mean_powers / norm_factor 1960 1961 # 1962 # Build frequency covariance matrix 1963 # 1964 1965 n = np.arange(fft_size) 1966 p = -2.*np.pi*subcarrier_spacing*n 1967 p = np.expand_dims(p, axis=0) 1968 delays = np.expand_dims(delays, axis=1) 1969 p = p*delays 1970 p = np.exp(1j*p) 1971 p = np.expand_dims(p, axis=-1) 1972 cov_mat = np.matmul(p, np.transpose(np.conj(p), [0, 2, 1])) 1973 mean_powers = np.expand_dims(mean_powers, axis=(1,2)) 1974 cov_mat = np.sum(mean_powers*cov_mat, axis=0) 1975 1976 return tf.cast(cov_mat, dtype) 1977 1978 def tdl_time_cov_mat(model, speed, carrier_frequency, ofdm_symbol_duration, 1979 num_ofdm_symbols, los_angle_of_arrival=PI/4., dtype=tf.complex64): 1980 # pylint: disable=line-too-long 1981 r""" 1982 Computes the time covariance matrix of a 1983 :class:`~sionna.channel.tr38901.TDL` channel model. 1984 1985 For non-line-of-sight (NLoS) model, the channel time covariance matrix 1986 :math:`\mathbf{R^{(t)}}` of a TDL channel model is 1987 1988 .. math:: 1989 \mathbf{R^{(t)}}_{u,v} = J_0 \left( \nu \Delta_t \left( u-v \right) \right) 1990 1991 where :math:`J_0` is the zero-order Bessel function of the first kind, 1992 :math:`\Delta_t` the duration of an OFDM symbol, and :math:`\nu` the Doppler 1993 spread defined by 1994 1995 .. math:: 1996 \nu = 2 \pi \frac{v}{c} f_c 1997 1998 where :math:`v` is the movement speed, :math:`c` the speed of light, and 1999 :math:`f_c` the carrier frequency. 2000 2001 For line-of-sight (LoS) channel models, the channel time covariance matrix 2002 is 2003 2004 .. 
math:: 2005 \mathbf{R^{(t)}}_{u,v} = P_{\text{NLoS}} J_0 \left( \nu \Delta_t \left( u-v \right) \right) + P_{\text{LoS}}e^{j \nu \Delta_t \left( u-v \right) \cos{\alpha_{\text{LoS}}}} 2006 2007 where :math:`\alpha_{\text{LoS}}` is the angle-of-arrival for the LoS path, 2008 :math:`P_{\text{NLoS}}` the total power of NLoS paths, and 2009 :math:`P_{\text{LoS}}` the power of the LoS path. The power delay profile 2010 is assumed to have unit power, i.e., :math:`P_{\text{NLoS}} + P_{\text{LoS}} = 1`. 2011 2012 Input 2013 ------ 2014 model : str 2015 TDL model for which to return the covariance matrix. 2016 Should be one of "A", "B", "C", "D", or "E". 2017 2018 speed : float 2019 Speed [m/s] 2020 2021 carrier_frequency : float 2022 Carrier frequency [Hz] 2023 2024 ofdm_symbol_duration : float 2025 Duration of an OFDM symbol [s] 2026 2027 num_ofdm_symbols : int 2028 Number of OFDM symbols 2029 2030 los_angle_of_arrival : float 2031 Angle-of-arrival for LoS path [radian]. Only used with LoS models. 2032 Defaults to :math:`\pi/4`. 2033 2034 dtype : tf.DType 2035 Datatype to use for the output. 2036 Should be one of `tf.complex64` or `tf.complex128`. 2037 Defaults to `tf.complex64`. 
2038 2039 Output 2040 ------ 2041 cov_mat : [num_ofdm_symbols, num_ofdm_symbols], tf.complex 2042 Channel time covariance matrix 2043 """ 2044 2045 # Doppler spread 2046 doppler_spread = 2.*PI*speed/SPEED_OF_LIGHT*carrier_frequency 2047 2048 # 2049 # Load the power delay profile 2050 # 2051 2052 # Set the file from which to load the model 2053 assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model" 2054 if model == 'A': 2055 parameters_fname = "TDL-A.json" 2056 elif model == 'B': 2057 parameters_fname = "TDL-B.json" 2058 elif model == 'C': 2059 parameters_fname = "TDL-C.json" 2060 elif model == 'D': 2061 parameters_fname = "TDL-D.json" 2062 else: # 'E' 2063 parameters_fname = "TDL-E.json" 2064 source = files(models).joinpath(parameters_fname) 2065 # pylint: disable=unspecified-encoding 2066 with open(source) as parameter_file: 2067 params = json.load(parameter_file) 2068 # LoS scenario ? 2069 los = bool(params['los']) 2070 # Retrieve power and delays 2071 mean_powers = np.power(10.0, np.array(params['powers'])/10.0) 2072 2073 # Normalize the PDP 2074 norm_factor = np.sum(mean_powers) 2075 mean_powers = mean_powers / norm_factor 2076 2077 if los: 2078 los_power = mean_powers[0] 2079 nlos_power = np.sum(mean_powers[1:]) 2080 else: 2081 nlos_power = np.sum(mean_powers) 2082 2083 # 2084 # Build time covariance matrix 2085 # 2086 2087 indices = np.arange(num_ofdm_symbols) 2088 s1 = np.expand_dims(indices, axis=1) 2089 s2 = np.expand_dims(indices, axis=0) 2090 exp = doppler_spread*ofdm_symbol_duration*(s1-s2) 2091 cov_mat_nlos = jv(0.0, exp)*nlos_power 2092 if los: 2093 cov_mat_los = np.exp(1j*exp*np.cos(los_angle_of_arrival))*los_power 2094 cov_mat = cov_mat_nlos+cov_mat_los 2095 else: 2096 cov_mat = cov_mat_nlos 2097 2098 return tf.cast(cov_mat, dtype)