config.py (4842B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Global Sionna configuration"""
import random
import numpy as np
import tensorflow as tf

class Config():
    """Sionna configuration class

    This singleton class is used to define global configuration variables
    and random number generators that can be accessed from all modules
    and functions. It is instantiated immediately and its properties can be
    accessed as "sionna.config.desired_property".

    Calling ``Config()`` again returns the same instance and leaves its
    current state (seed, random number generators, ``xla_compat``)
    untouched.
    """

    # This object is a singleton
    _instance = None
    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            cls._instance = instance
        return cls._instance

    def __init__(self):
        # ``__new__`` hands back the shared instance, but Python still runs
        # ``__init__`` on every ``Config()`` call. Initialize only once so a
        # repeated instantiation does not discard a configured seed, the
        # generators, or the ``xla_compat`` flag.
        if getattr(self, "_initialized", False):
            return

        self._seed = None
        self._py_rng = None
        self._np_rng = None
        self._tf_rng = None
        self._xla_compat = None

        # Set default properties
        self.xla_compat = False

        self._initialized = True

    @property
    def py_rng(self):
        """
        random.Random() : Python random number generator

        Created lazily on first access.

        .. code-block:: python

            import sionna
            sionna.config.seed = 42 # Set seed for deterministic results

            # Use generator instead of random
            value = sionna.config.py_rng.randint(0, 10)
        """
        if self._py_rng is None:
            self._py_rng = random.Random()
        return self._py_rng

    @property
    def np_rng(self):
        """
        np.random.Generator : NumPy random number generator

        Created lazily on first access. Note that setting :attr:`seed`
        replaces this generator with a new object.

        .. code-block:: python

            import sionna
            sionna.config.seed = 42 # Set seed for deterministic results

            # Use generator instead of np.random
            noise = sionna.config.np_rng.normal(size=[4])
        """
        if self._np_rng is None:
            self._np_rng = np.random.default_rng()
        return self._np_rng

    @property
    def tf_rng(self):
        """
        tf.random.Generator : TensorFlow random number generator

        Created lazily on first access from a non-deterministic state.
        Setting :attr:`seed` re-seeds this generator in place rather than
        replacing it.

        .. code-block:: python

            import sionna
            sionna.config.seed = 42 # Set seed for deterministic results

            # Use generator instead of tf.random
            noise = sionna.config.tf_rng.normal([4])
        """
        if self._tf_rng is None:
            self._tf_rng = tf.random.Generator.from_non_deterministic_state()
        return self._tf_rng

    @property
    def seed(self):
        # pylint: disable=line-too-long
        """Get/set seed for all random number generators

        All random number generators used internally by Sionna
        can be configured with a common seed to ensure reproducibility
        of results. It defaults to `None` which implies that a random
        seed will be used and results are non-deterministic. Setting it
        back to `None` re-seeds all generators non-deterministically.

        .. code-block:: python

            # This code will lead to deterministic results
            import sionna
            sionna.config.seed = 42
            print(sionna.utils.BinarySource()([10]))

        .. code-block:: console

            tf.Tensor([0. 1. 1. 1. 1. 0. 1. 0. 1. 0.], shape=(10,), dtype=float32)

        :type: int
        """
        return self._seed

    @seed.setter
    def seed(self, seed):
        # Store seed
        if seed is not None:
            seed = int(seed)
        self._seed = seed

        # TensorFlow: re-seed the existing generator in place so that the
        # object returned by ``tf_rng`` keeps its identity.
        if seed is None:
            # ``reset_from_seed`` cannot build a state from ``None``; draw a
            # fresh non-deterministic state with the same algorithm instead.
            fresh = tf.random.Generator.from_non_deterministic_state(
                alg=self.tf_rng.algorithm)
            self.tf_rng.reset(fresh.state)
        else:
            self.tf_rng.reset_from_seed(seed)

        # Python (``seed(None)`` falls back to OS randomness / current time)
        self.py_rng.seed(seed)

        # NumPy (a new generator; ``default_rng(None)`` is non-deterministic)
        self._np_rng = np.random.default_rng(seed)

    @property
    def xla_compat(self):
        """Ensure that functions execute in an XLA compatible way.

        Not all TensorFlow ops support the three execution modes for
        all dtypes: Eager, Graph, and Graph with XLA. For this reason,
        some functions are implemented differently depending on the
        execution mode. As it is currently impossible to programmatically
        determine if a function is executed in Graph or Graph with XLA mode,
        the ``xla_compat`` property can be used to indicate which execution
        mode is desired. Note that most functions will work in all execution
        modes independently of the value of this property.

        This property can be used like this:

        .. code-block:: python

            import tensorflow as tf
            import sionna
            sionna.config.xla_compat=True
            @tf.function(jit_compile=True)
            def func():
                ... # Implementation

            func()

        :type: bool
        """
        return self._xla_compat

    @xla_compat.setter
    def xla_compat(self, value):
        self._xla_compat = bool(value)
        if self._xla_compat:
            # Printed on every assignment of a truthy value
            msg = "XLA can lead to reduced numerical precision." \
                + " Use with care."
            print(msg)

config = Config()