pusch_pilot_pattern.py (4253B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH pilot pattern for the nr (5G) sub-package of the Sionna library.
"""
import warnings
from collections.abc import Sequence
import tensorflow as tf
import numpy as np
from sionna.ofdm import PilotPattern
from .pusch_config import PUSCHConfig

class PUSCHPilotPattern(PilotPattern):
    # pylint: disable=line-too-long
    r"""Class defining a pilot pattern for NR PUSCH.

    This class defines a :class:`~sionna.ofdm.PilotPattern`
    that is used to configure an OFDM :class:`~sionna.ofdm.ResourceGrid`.

    For every transmitter, a separate :class:`~sionna.nr.PUSCHConfig`
    needs to be provided from which the pilot pattern will be created.

    Parameters
    ----------
    pusch_configs : instance or list of :class:`~sionna.nr.PUSCHConfig`
        PUSCH Configurations according to which the pilot pattern
        will be created. One configuration is needed for each transmitter.
        All configurations must agree on the number of layers, the number
        of subcarriers, the number of OFDM symbols (``l_d``), the precoding
        method and the number of DMRS resource elements.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.

    Raises
    ------
    ValueError
        If ``pusch_configs`` is neither a :class:`~sionna.nr.PUSCHConfig`
        nor a sequence, or if it is an empty sequence.

    AssertionError
        If an element of ``pusch_configs`` is not a
        :class:`~sionna.nr.PUSCHConfig`, or if the configurations are
        mutually inconsistent (only when assertions are enabled).

    Warns
    -----
    UserWarning
        If the same DMRS port is used by more than one transmitter.
    """
    def __init__(self,
                 pusch_configs,
                 dtype=tf.complex64):

        # Normalize pusch_configs to a list and check the element types
        if isinstance(pusch_configs, PUSCHConfig):
            pusch_configs = [pusch_configs]
        elif isinstance(pusch_configs, Sequence):
            for c in pusch_configs:
                assert isinstance(c, PUSCHConfig), \
                    "Each element of pusch_configs must be a valid PUSCHConfig"
        else:
            raise ValueError("Invalid value for pusch_configs")

        # Without this check an empty sequence fails below with an
        # unhelpful IndexError on pusch_configs[0].
        if len(pusch_configs) == 0:
            raise ValueError(
                "pusch_configs must contain at least one PUSCHConfig")

        # The first configuration is the reference that all others
        # must agree with.
        ref_config = pusch_configs[0]
        num_tx = len(pusch_configs)
        num_streams_per_tx = ref_config.num_layers
        # dmrs_grid is indexed as [layer, subcarrier, symbol] below, so
        # axis 0 of one layer's grid counts the subcarriers.
        num_subcarriers = ref_config.dmrs_grid[0].shape[0]
        num_ofdm_symbols = ref_config.l_d
        precoding = ref_config.precoding
        num_pilots = np.sum(ref_config.dmrs_mask)

        # Check validity of provided pusch_configs
        dmrs_ports = []
        for pusch_config in pusch_configs:
            assert pusch_config.num_layers==num_streams_per_tx, \
                "All pusch_configs must have the same number of layers"
            assert pusch_config.dmrs_grid[0].shape[0]==num_subcarriers, \
                "All pusch_configs must have the same number of subcarriers"
            assert pusch_config.l_d==num_ofdm_symbols, \
                "All pusch_configs must have the same number of OFDM symbols"
            assert pusch_config.precoding==precoding, \
                "All pusch_configs must have the same precoding method"
            assert np.sum(pusch_config.dmrs_mask)==num_pilots, \
                "All pusch_configs must have the same number of masked REs"
            # Warn about ports already claimed by an earlier transmitter.
            # 'always' ensures every collision is reported, not only the
            # first one seen at this code location.
            with warnings.catch_warnings():
                warnings.simplefilter('always')
                for port in pusch_config.dmrs.dmrs_port_set:
                    if port in dmrs_ports:
                        msg = f"DMRS port {port} used by multiple transmitters"
                        warnings.warn(msg)
            dmrs_ports.extend(pusch_config.dmrs.dmrs_port_set)

        # Create mask and pilots tensors
        mask = np.zeros([num_tx,
                         num_streams_per_tx,
                         num_ofdm_symbols,
                         num_subcarriers], bool)
        pilots = np.zeros([num_tx, num_streams_per_tx, num_pilots], complex)
        for i, pusch_config in enumerate(pusch_configs):
            # symbol_allocation is unpacked as (first symbol, number of
            # symbols); only that window of the slot enters the pattern.
            first_symbol, num_symbols = pusch_config.symbol_allocation
            window = slice(first_symbol, first_symbol + num_symbols)

            # The DMRS mask is shared by all layers of a transmitter, so it
            # is computed once. Transposing yields the [symbol, subcarrier]
            # layout of `mask`; astype(bool) matches the conversion that
            # storing into the bool `mask` array performs.
            tx_mask = np.transpose(
                pusch_config.dmrs_mask[:, window]).astype(bool)
            mask[i] = tx_mask  # broadcast over the stream dimension

            dmrs_grid = pusch_config.dmrs_grid
            for j in range(num_streams_per_tx):
                layer_grid = np.transpose(dmrs_grid[j, :, window])
                # Boolean indexing returns the masked entries in row-major
                # (symbol, subcarrier) order, like indexing with np.where.
                pilots[i, j] = layer_grid[tx_mask]

        # Init PilotPattern class
        super().__init__(mask, pilots,
                         trainable=False,
                         normalize=False,
                         dtype=dtype)