awgn.py (2679B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for simulating an AWGN channel"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.utils import expand_to_rank, complex_normal


class AWGN(Layer):
    r"""AWGN(dtype=tf.complex64, **kwargs)

    Adds complex additive white Gaussian noise (AWGN) of a given variance
    to the input.

    This class inherits from the Keras `Layer` class and can be used as a
    layer in a Keras model.

    For a noise power ``no``, the added noise is complex Gaussian with
    variance ``no``, i.e., variance ``no/2`` in each of the real and
    imaginary parts. ``no`` can be a scalar or a tensor that can be
    broadcast to the shape of the input.

    Example
    -------

    Setting-up:

    >>> awgn_channel = AWGN()

    Running:

    >>> # x is the channel input
    >>> # no is the noise variance
    >>> y = awgn_channel((x, no))

    Parameters
    ----------
    dtype : Complex tf.DType
        Datatype of the layer. Its real counterpart (e.g., `tf.float32`
        for `tf.complex64`) is the precision in which ``no`` and its
        square root are computed. The noise itself is generated in the
        dtype of ``x``, so ``x`` is expected to already have this dtype.
        Defaults to `tf.complex64`.

    Input
    -----

    (x, no) :
        Tuple:

    x : Tensor, tf.complex
        Channel input

    no : Scalar or Tensor, tf.float
        Noise power per complex dimension. A scalar adds noise of the same
        variance to every element of ``x``. A tensor must be broadcastable
        to the shape of ``x``, which allows, e.g., a different noise
        variance for each example in a batch. If ``no`` has a lower rank
        than ``x``, dummy dimensions are appended after its last axis until
        the ranks match, and standard broadcasting then applies.

    Output
    ------
    y : Tensor with same shape and dtype as ``x``, tf.complex
        Channel output
    """

    def __init__(self, dtype=tf.complex64, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Real-valued counterpart of the layer dtype; ``no`` is cast to it.
        self._real_dtype = tf.as_dtype(self._dtype).real_dtype

    def call(self, inputs):
        """Return ``x`` plus complex Gaussian noise of power ``no``."""
        x, no = inputs

        # Complex Gaussian samples shaped like ``x`` and in ``x``'s dtype.
        # They are assumed to have unit variance, so scaling by sqrt(no)
        # yields variance ``no`` as documented above.
        unit_noise = complex_normal(tf.shape(x), dtype=x.dtype)

        # Append trailing dummy axes so ``no`` broadcasts against ``x``,
        # then form the standard deviation in the layer's real precision.
        noise_power = tf.cast(expand_to_rank(no, tf.rank(x), axis=-1),
                              self._real_dtype)
        noise_std = tf.cast(tf.sqrt(noise_power), unit_noise.dtype)

        return x + unit_noise * noise_std