apply_time_channel.py (6612B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for applying channel responses to channel inputs in the time domain"""

import tensorflow as tf

import numpy as np

import scipy

from sionna.utils import insert_dims
from .awgn import AWGN

class ApplyTimeChannel(tf.keras.layers.Layer):
    # pylint: disable=line-too-long
    r"""ApplyTimeChannel(num_time_samples, l_tot, add_awgn=True, dtype=tf.complex64, **kwargs)

    Apply time domain channel responses ``h_time`` to channel inputs ``x``,
    by filtering the channel inputs with time-variant channel responses.

    This class inherits from the Keras `Layer` class and can be used as layer
    in a Keras model.

    For each batch example, ``num_time_samples`` + ``l_tot`` - 1 time steps of a
    channel realization are required to filter the channel inputs.

    The channel output consists of ``num_time_samples`` + ``l_tot`` - 1
    time samples, as it is the result of filtering the channel input of length
    ``num_time_samples`` with the time-variant channel filter of length
    ``l_tot``. In the case of a single-input single-output link and given a sequence of channel
    inputs :math:`x_0,\cdots,x_{N_B-1}`, where :math:`N_B` is ``num_time_samples``, this
    layer outputs

    .. math::
        y_b = \sum_{\ell = 0}^{L_{\text{tot}}-1} x_{b-\ell} \bar{h}_{b,\ell} + w_b

    where :math:`L_{\text{tot}}` corresponds ``l_tot``, :math:`w_b` to the additive noise, and
    :math:`\bar{h}_{b,\ell}` to the :math:`\ell^{th}` tap of the :math:`b^{th}` channel sample.
    This layer outputs :math:`y_b` for :math:`b` ranging from 0 to
    :math:`N_B + L_{\text{tot}} - 2`, and :math:`x_{b}` is set to 0 for :math:`b \geq N_B`.

    For multiple-input multiple-output (MIMO) links, the channel output is computed for each antenna
    of each receiver and by summing over all the antennas of all transmitters.

    Parameters
    ----------

    num_time_samples : int
        Number of time samples forming the channel input (:math:`N_B`)

    l_tot : int
        Length of the channel filter (:math:`L_{\text{tot}} = L_{\text{max}} - L_{\text{min}} + 1`)

    add_awgn : bool
        If set to `False`, no white Gaussian noise is added.
        Defaults to `True`.

    dtype : tf.DType
        Complex datatype to use for internal processing and output.
        Defaults to `tf.complex64`.

    Input
    -----

    (x, h_time, no) or (x, h_time):
        Tuple:

    x :  [batch size, num_tx, num_tx_ant, num_time_samples], tf.complex
        Channel inputs

    h_time : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_tot - 1, l_tot], tf.complex
        Channel responses.
        For each batch example, ``num_time_samples`` + ``l_tot`` - 1 time steps of a
        channel realization are required to filter the channel inputs.

    no : Scalar or Tensor, tf.float
        Scalar or tensor whose shape can be broadcast to the shape of the channel outputs: [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1].
        Only required if ``add_awgn`` is set to `True`.
        The noise power ``no`` is per complex dimension. If ``no`` is a
        scalar, noise of the same variance will be added to the outputs.
        If ``no`` is a tensor, it must have a shape that can be broadcast to
        the shape of the channel outputs. This allows, e.g., adding noise of
        different variance to each example in a batch. If ``no`` has a lower
        rank than the channel outputs, then ``no`` will be broadcast to the
        shape of the channel outputs by adding dummy dimensions after the
        last axis.

    Output
    -------
    y : [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1], tf.complex
        Channel outputs.
        The channel output consists of ``num_time_samples`` + ``l_tot`` - 1
        time samples, as it is the result of filtering the channel input of length
        ``num_time_samples`` with the time-variant channel filter of length
        ``l_tot``.
    """

    def __init__(self, num_time_samples, l_tot, add_awgn=True,
                 dtype=tf.complex64, **kwargs):

        super().__init__(trainable=False, dtype=dtype, **kwargs)

        self._add_awgn = add_awgn

        # The channel transfer function is implemented by first gathering from
        # the vector of transmitted baseband symbols
        # x = [x_0,...,x_{num_time_samples-1}]^T the symbols that are then
        # multiplied by the channel tap coefficients.
        # We build here the matrix of indices G, with size
        # `num_time_samples + l_tot - 1` x `l_tot` that is used to perform this
        # gathering.
        # For example, if there are 4 channel taps
        # h = [h_0, h_1, h_2, h_3]^T
        # and `num_time_samples` = 10 time steps then G would be
        #       [[0, 10, 10, 10]
        #        [1,  0, 10, 10]
        #        [2,  1,  0, 10]
        #        [3,  2,  1,  0]
        #        [4,  3,  2,  1]
        #        [5,  4,  3,  2]
        #        [6,  5,  4,  3]
        #        [7,  6,  5,  4]
        #        [8,  7,  6,  5]
        #        [9,  8,  7,  6]
        #        [10, 9,  8,  7]
        #        [10,10,  9,  8]
        #        [10,10, 10,  9]]
        # Note that G is a Toeplitz matrix.
        # In this example, the index `num_time_samples`=10 corresponds to the
        # zero symbol. The vector of transmitted symbols is padded with one
        # zero at the end.
        first_column = np.concatenate([ np.arange(0, num_time_samples),
                                        np.full([l_tot-1], num_time_samples)])
        first_row = np.concatenate([[0], np.full([l_tot-1], num_time_samples)])
        self._g = scipy.linalg.toeplitz(first_column, first_row)

    def build(self, input_shape): #pylint: disable=unused-argument

        # The AWGN layer is created lazily so that it picks up this layer's
        # (possibly overridden) dtype.
        if self._add_awgn:
            self._awgn = AWGN(dtype=self.dtype)

    def call(self, inputs):

        if self._add_awgn:
            x, h_time, no = inputs
        else:
            x, h_time = inputs

        # Preparing the channel input for broadcasting and matrix multiplication.
        # One zero symbol is appended at the end; its index in the padded vector
        # is `num_time_samples`, which is what out-of-range entries of G map to.
        x = tf.pad(x, [[0,0], [0,0], [0,0], [0,1]])
        # Insert [num_rx, num_rx_ant] broadcast dimensions after the batch axis
        x = insert_dims(x, 2, axis=1)

        # Gather the delayed symbol copies: last axis becomes [..., time, l_tot]
        x = tf.gather(x, self._g, axis=-1)

        # Apply the channel response: weight by the taps and sum over them
        y = tf.reduce_sum(h_time*x, axis=-1)
        # Sum over all transmit antennas (axis 4) of all transmitters (axis 3)
        y = tf.reduce_sum(y, axis=[3, 4])

        # Add AWGN if requested
        if self._add_awgn:
            y = self._awgn((y, no))

        return y