modulator.py (4423B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for the OFDM Modulator"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.signal import ifftshift
from sionna.utils import flatten_last_dims
from sionna.signal import ifft


class OFDMModulator(Layer):
    # pylint: disable=line-too-long
    """
    OFDMModulator(cyclic_prefix_length=0, **kwargs)

    Computes the time-domain representation of an OFDM resource grid
    with (optional) cyclic prefix

    Parameters
    ----------
    cyclic_prefix_length : scalar or [num_ofdm_symbols], int
        Integer or vector of integers indicating the length of the
        cyclic prefix that is prepended to each OFDM symbol. None of its
        elements can be larger than the FFT size.
        Defaults to 0.

    Input
    -----
    : [...,num_ofdm_symbols,fft_size], tf.complex
        Resource grid in the frequency domain

    Output
    ------
    : [...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)], tf.complex
        Time-domain OFDM signal
    """

    def __init__(self, cyclic_prefix_length=0, **kwargs):
        super().__init__(**kwargs)
        self._cyclic_prefix_length = None
        # Resource-grid dimensions recorded by `build`; None until built
        self._num_ofdm_symbols = None
        self._fft_size = None
        # Flattened gather indices, only used for per-symbol (rank-1) CP lengths
        self._ind = None
        self.cyclic_prefix_length = cyclic_prefix_length

    @property
    def cyclic_prefix_length(self):
        """
        scalar or [num_ofdm_symbols], int : Get/set the cyclic prefix length.
        Setting a new value on an already built layer re-validates it against
        the resource grid and recomputes the internal gather indices.
        """
        return self._cyclic_prefix_length

    @cyclic_prefix_length.setter
    def cyclic_prefix_length(self, value):
        value = tf.cast(value, tf.int32)
        if not tf.reduce_all(value >= 0):
            msg = "`cyclic_prefix_length` must be nonnegative."
            raise ValueError(msg)
        # Static rank check; avoids calling bool() on intermediate tensors
        # of a chained tensor comparison
        if value.shape.rank not in (0, 1):
            msg = "`cyclic_prefix_length` must be of rank 0 or 1"
            raise ValueError(msg)
        if self._fft_size is not None:
            # Keras calls `build` only once, so indices derived from a
            # previous value would otherwise be silently reused by `call`.
            # NOTE(review): inside an already traced tf.function, the old
            # values may still be captured until the function is retraced.
            self._ind = self._compute_indices(value,
                                              self._num_ofdm_symbols,
                                              self._fft_size)
        # Assign only after all validation has succeeded
        self._cyclic_prefix_length = value

    @staticmethod
    def _compute_indices(cyclic_prefix_length, num_ofdm_symbols, fft_size):
        """Validate the CP length against the grid and build gather indices.

        Returns the flattened time-domain gather indices of shape
        [num_ofdm_symbols*fft_size + sum(cyclic_prefix_length)] for a rank-1
        (per-symbol) CP length, or `None` for a scalar CP length (which
        `call` handles by slicing instead).

        Raises ValueError if any CP length exceeds `fft_size`, or if a
        rank-1 CP length does not have `num_ofdm_symbols` entries.
        """
        if not tf.reduce_all(cyclic_prefix_length <= fft_size):
            msg = "`cyclic_prefix_length` cannot be larger than `fft_size`."
            raise ValueError(msg)
        if len(cyclic_prefix_length.shape) == 0:
            # Scalar CP: no gather indices needed
            return None
        if not cyclic_prefix_length.shape[0] == num_ofdm_symbols:
            msg = "`cyclic_prefix_length` must be of size [num_ofdm_symbols]"
            raise ValueError(msg)

        # Index one past the last sample of each OFDM symbol in the
        # flattened (CP-free) signal
        # [num_ofdm_symbols, 1]
        offsets = tf.expand_dims(tf.range(1, num_ofdm_symbols + 1) * fft_size,
                                 1)

        # Indices of the CP samples: the last cp_i samples of symbol i
        # [num_ofdm_symbols, None] (ragged tensor)
        cp_ind = tf.ragged.range(starts=-cyclic_prefix_length,
                                 limits=0) + offsets

        # Indices of the samples containing the actual sequence
        # [num_ofdm_symbols, fft_size]
        data_ind = tf.repeat(tf.expand_dims(tf.range(0, fft_size), 0),
                             num_ofdm_symbols, 0) + offsets - fft_size

        # Concat CP and sequence indices
        # [num_ofdm_symbols, None]
        ind = tf.concat([cp_ind, data_ind], axis=-1)

        # Flatten in time domain
        # [num_ofdm_symbols*fft_size + sum(cyclic_prefix_length)]
        return ind.flat_values

    def build(self, input_shape):
        num_ofdm_symbols, fft_size = input_shape[-2:]
        self._ind = self._compute_indices(self.cyclic_prefix_length,
                                          num_ofdm_symbols,
                                          fft_size)
        # Remember the grid so that later CP updates can be re-validated
        self._num_ofdm_symbols = num_ofdm_symbols
        self._fft_size = fft_size

    def call(self, inputs):
        # Shift DC subcarrier to first position
        x_freq = ifftshift(inputs, axes=-1)

        # Compute IFFT along the last dimension
        x_time = ifft(x_freq)

        if len(self.cyclic_prefix_length.shape) == 1:
            # Individual CP length per OFDM symbol:
            # flatten the last two dimensions, then gather CP + data samples
            x_time = flatten_last_dims(x_time, 2)
            return tf.gather(x_time, self._ind, axis=-1)

        # Same CP length for all OFDM symbols:
        # the CP is the tail of every OFDM symbol
        cp = x_time[..., tf.shape(x_time)[-1] - self._cyclic_prefix_length:]

        # Prepend cyclic prefix
        x_time = tf.concat([cp, x_time], -1)

        # Serialize last two dimensions
        return flatten_last_dims(x_time, 2)