scrambling.py (25638B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for scrambling, descrambling and utility functions."""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna import config
from sionna.utils import expand_to_rank
from sionna.nr.utils import generate_prng_seq


class Scrambler(Layer):
    # pylint: disable=line-too-long
    r"""Scrambler(seed=None, keep_batch_constant=False, binary=True, sequence=None, keep_state=True, dtype=tf.float32, **kwargs)

    Randomly flips the state/sign of a sequence of bits or LLRs, respectively.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    seed: int
        Defaults to None. Defines the initial state of the
        pseudo random generator to generate the scrambling sequence.
        If None, a random integer will be generated. Only used
        when ``keep_state`` is True.

    keep_batch_constant: bool
        Defaults to False. If True, all samples in the batch are scrambled
        with the same scrambling sequence. Otherwise, per sample a random
        sequence is generated.

    binary: bool
        Defaults to True. Indicates whether bit-sequence should be flipped
        (i.e., binary operations are performed) or the signs should be
        flipped (i.e., soft-value/LLR domain-based).

    sequence: Array of 0s and 1s or None
        If provided, the seed will be ignored and the explicit scrambling
        sequence is used. Shape must be broadcastable to ``x``.

    keep_state: bool
        Defaults to True. Indicates whether the scrambling sequence should
        be kept constant.

    dtype: tf.DType
        Defaults to `tf.float32`. Defines the datatype for internal
        calculations and the output dtype.

    Input
    -----
    (x, seed, binary):
        Either Tuple ``(x, seed, binary)`` or ``(x, seed)`` or ``x`` only
        (no tuple) if the internal seed should be used:

    x: tf.float
        1+D tensor of arbitrary shape.
    seed: int
        An integer defining the state of the random number
        generator. If explicitly given, the global internal seed is
        replaced by this seed. Can be used to realize random
        scrambler/descrambler pairs (call with same random seed).
    binary: bool
        Overrules the init parameter `binary` iff explicitly given.
        Indicates whether bit-sequence should be flipped
        (i.e., binary operations are performed) or the signs should be
        flipped (i.e., soft-value/LLR domain-based).

    Output
    ------
    : tf.float
        1+D tensor of same shape as ``x``.

    Note
    ----
    For inverse scrambling, the same scrambler can be re-used (as the values
    are flipped again, i.e., result in the original state). However,
    ``keep_state`` must be set to True as a new sequence would be generated
    otherwise.

    The scrambler layer is stateless, i.e., the seed is either random
    during each call or must be explicitly provided during init/call.
    This simplifies XLA/graph execution.
    If the seed is provided in the init() function, this fixed seed is used
    for all calls. However, an explicit seed can be provided during
    the call function to realize `true` random states.

    Scrambling is typically used to ensure equally likely `0` and `1` for
    sources with unequal bit probabilities. As we have a perfect source in
    the simulations, this is not required. However, for all-zero codeword
    simulations and higher-order modulation, so-called "channel-adaptation"
    [Pfister03]_ is required.

    Raises
    ------
    AssertionError
        If ``seed`` is not int.

    AssertionError
        If ``keep_batch_constant`` is not bool.

    AssertionError
        If ``binary`` is not bool.

    AssertionError
        If ``keep_state`` is not bool.

    AssertionError
        If ``seed`` is provided to list of inputs but not an
        int.

    TypeError
        If `dtype` of ``x`` is not as expected.
    """
    def __init__(self,
                 seed=None,
                 keep_batch_constant=False,
                 binary=True,
                 sequence=None,
                 keep_state=True,
                 dtype=tf.float32,
                 **kwargs):

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
                         tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise TypeError("Unsupported dtype.")

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(keep_batch_constant, bool), \
            "keep_batch_constant must be bool."
        self._keep_batch_constant = keep_batch_constant

        if seed is not None:
            if sequence is not None:
                print("Note: explicit scrambling sequence provided. " \
                      "Seed will be ignored.")
            assert isinstance(seed, int), "seed must be int."
        else:
            # draw a random seed once; it is then reused for every call
            # as long as keep_state is True
            seed = int(np.random.uniform(0, 2**31-1))

        assert isinstance(binary, bool), "binary must be bool."
        assert isinstance(keep_state, bool), "keep_state must be bool."

        self._binary = binary
        self._keep_state = keep_state

        self._check_input = True

        # if keep_state==True this seed is used to generate scrambling sequences
        # (stateless TF RNGs expect a seed of shape [2]; the first entry is a
        # fixed constant)
        self._seed = (1337, seed)

        # if an explicit sequence is provided the above parameters will be
        # ignored
        self._sequence = None
        if sequence is not None:
            sequence = tf.cast(sequence, self.dtype)
            # check that sequence is binary
            tf.debugging.assert_equal(
                tf.reduce_min(
                    tf.cast(
                        tf.logical_or(
                            tf.equal(sequence, tf.constant(0, self.dtype)),
                            tf.equal(sequence, tf.constant(1, self.dtype)),),
                        self.dtype)),
                tf.constant(1, self.dtype),
                "Scrambling sequence must be binary.")
            self._sequence = sequence

    #########################################
    # Public methods and properties
    #########################################

    @property
    def seed(self):
        """Seed used to generate random sequence."""
        return self._seed[1] # only return the non-fixed seed

    @property
    def keep_state(self):
        """Indicates if the scrambling sequence is kept constant, i.e., no
        new random sequence is drawn per call."""
        return self._keep_state

    @property
    def sequence(self):
        """Explicit scrambling sequence if provided."""
        return self._sequence

    #########################
    # Utility methods
    #########################

    def _generate_scrambling(self, input_shape, seed):
        r"""Generates a random sequence of `0`s and `1`s.

        Args:
            input_shape: Shape of the input tensor ``x`` (as returned by
                ``tf.shape``).
            seed: Stateless seed of shape `[2]`.

        Returns:
            Tensor of `0`s and `1`s with dtype ``self.dtype``. If
            ``keep_batch_constant`` is True, the leading (batch) dimension
            has size 1 so that it broadcasts over the batch.
        """
        if self._keep_batch_constant:
            input_shape_no_bs = input_shape[1:]
            seq = tf.random.stateless_uniform(input_shape_no_bs,
                                              seed,
                                              minval=0,
                                              maxval=2,
                                              dtype=tf.int32)
            # expand batch dim such that it can be broadcasted
            seq = tf.expand_dims(seq, axis=0)
        else:
            seq = tf.random.stateless_uniform(input_shape,
                                              seed,
                                              minval=0,
                                              maxval=2,
                                              dtype=tf.int32)

        return tf.cast(seq, self.dtype) # enable flexible dtypes

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build the model and initialize variables."""
        pass

    def call(self, inputs):
        r"""Scrambling function.

        This function returns the scrambled version of ``inputs``.

        ``inputs`` can be either a list ``[x]``, ``[x, seed]``,
        ``[x, seed, binary]`` or a single tensor ``x``.

        Args:
            inputs (List): ``[x, seed]``, where
            ``x`` (tf.float32): Tensor of arbitrary shape.
            ``seed`` (int): An integer defining the state of the random number
                generator. If explicitly given, the global internal seed is
                replaced by this seed. Can be used to realize random
                scrambler/descrambler pairs (call with same random seed).

        Returns:
            `tf.float32`: Tensor of same shape as the input.

        Raises:
            AssertionError
                If ``binary`` is given as non-tensor but is not a bool.

            TypeError
                If ``binary`` is a non-boolean tensor, if more than 3 inputs
                are given, or if `dtype` of ``x`` is not as expected.
        """
        is_binary = self._binary # can be overwritten if explicitly provided

        if isinstance(inputs, (tuple, list)):
            if len(inputs)==1: # if user wants to call with call([x])
                seed = None
                x = inputs[0] # unpack the tensor, not the list itself
            elif len(inputs)==2:
                x, seed = inputs
            elif len(inputs)==3:
                # allow that is_binary flag can be explicitly provided (descrambler)
                x, seed, is_binary = inputs
                # is_binary can be either a tensor or a Python bool
                if isinstance(is_binary, tf.Tensor):
                    if not is_binary.dtype.is_bool:
                        raise TypeError("binary must be bool.")
                else: # is boolean
                    assert isinstance(is_binary, bool), \
                        "binary must be bool."
            else:
                raise TypeError("inputs cannot have more than 3 entries.")
        else:
            seed = None
            x = inputs

        tf.debugging.assert_type(x, self.dtype,
                                 "Invalid input dtype.")

        input_shape = tf.shape(x)

        # generate random sequence on-the-fly (due to unknown shapes during
        # compile/build time)
        # use seed if explicit seed is provided
        if seed is not None:
            # NOTE: the seed is forwarded to tf.random.stateless_uniform,
            # which requires an integer seed
            seed = (1337, seed)
        # only generate a new random sequence if keep_state==False
        elif self._keep_state:
            # use sequence as defined by seed
            seed = self._seed
        else:
            # generate new seed for each call
            # Note: not necessarily random if XLA is active
            seed = config.tf_rng.uniform([2],
                                         minval=0,
                                         maxval=2**31-1,
                                         dtype=tf.int32)

        # apply sequence if explicit sequence is provided
        if self._sequence is not None:
            rand_seq = self._sequence
        else:
            rand_seq = self._generate_scrambling(input_shape, seed)

        if is_binary:
            # flip the bits by subtraction and map -1 to 1 via abs(.) operator
            x_out = tf.abs(x - rand_seq)
        else:
            # map {0,1} -> {+1,-1} and flip the signs accordingly
            rand_seq_bipol = -2 * rand_seq + 1
            x_out = tf.multiply(x, rand_seq_bipol)

        return x_out

class TB5GScrambler(Layer):
    # pylint: disable=line-too-long
    r"""TB5GScrambler(n_rnti=1, n_id=1, binary=True, channel_type="PUSCH", codeword_index=0, dtype=tf.float32, **kwargs)

    Implements the pseudo-random bit scrambling as defined in
    [3GPPTS38211_scr]_ Sec. 6.3.1.1 for the "PUSCH" channel and in Sec. 7.3.1.1
    for the "PDSCH" channel.

    Only for the "PDSCH" channel, the scrambler can be configured for two
    codeword transmission mode. Hereby, ``codeword_index`` corresponds to the
    index of the codeword to be scrambled.

    If ``n_rnti`` is a list of ints, the scrambler assumes that the second
    last axis contains `len(` ``n_rnti`` `)` elements. This allows independent
    scrambling for multiple independent streams.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    n_rnti: int or list of ints
        RNTI identifier provided by higher layer. Defaults to 1 and must be
        in range `[0, 65535]`. If a list is provided, every list element
        defines a scrambling sequence for multiple independent streams.

    n_id: int or list of ints
        Scrambling ID related to cell id and provided by higher layer.
        Defaults to 1 and must be in range `[0, 1023]`. If a list is
        provided, every list element defines a scrambling sequence for
        multiple independent streams.

    binary: bool
        Defaults to True. Indicates whether bit-sequence should be flipped
        (i.e., binary operations are performed) or the signs should be
        flipped (i.e., soft-value/LLR domain-based).

    channel_type: str
        Can be either "PUSCH" or "PDSCH".

    codeword_index: int
        Scrambler can be configured for two codeword transmission.
        ``codeword_index`` can be either 0 or 1.

    dtype: tf.DType
        Defaults to `tf.float32`. Defines the datatype for internal
        calculations and the output dtype.

    Input
    -----
    (x, binary):
        Either Tuple ``(x, binary)`` or ``x`` only

    x: tf.float
        1+D tensor of arbitrary shape. If ``n_rnti`` and ``n_id`` are a
        list, it is assumed that ``x`` has shape
        `[...,num_streams, n]` where `num_streams=len(` ``n_rnti`` `)`.

    binary: bool
        Overrules the init parameter `binary` iff explicitly given.
        Indicates whether bit-sequence should be flipped
        (i.e., binary operations are performed) or the signs should be
        flipped (i.e., soft-value/LLR domain-based).

    Output
    ------
    : tf.float
        1+D tensor of same shape as ``x``.

    Note
    ----
    The parameters radio network temporary identifier (RNTI) ``n_rnti`` and
    the data scrambling ID ``n_id`` are usually provided by the higher layer
    protocols.

    For inverse scrambling, the same scrambler can be re-used (as the values
    are flipped again, i.e., result in the original state).
    """
    def __init__(self,
                 n_rnti=1,
                 n_id=1,
                 binary=True,
                 channel_type="PUSCH",
                 codeword_index=0,
                 dtype=tf.float32,
                 **kwargs):

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
                         tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise TypeError("Unsupported dtype.")

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(binary, bool), "binary must be bool."
        assert channel_type in ("PDSCH", "PUSCH"), "Unsupported channel_type."
        assert codeword_index in (0, 1), "codeword_index must be 0 or 1."

        self._binary = binary
        self._check_input = True
        # shape for which self._sequence was generated (set in build)
        self._input_shape = None
        self._sequence = None

        # allow list input for independent multi-stream scrambling
        if isinstance(n_rnti, (list, tuple)):
            assert isinstance(n_id, (list, tuple)), \
                "n_id must be a list of same length as n_rnti."

            assert len(n_rnti)==len(n_id), \
                "n_rnti and n_id must be of same length."
            assert len(n_rnti) > 0, "n_rnti must not be empty."

            self._multi_stream = True
            # work on copies: never mutate the caller's list and support
            # (immutable) tuples
            n_rnti = list(n_rnti)
            n_id = list(n_id)
        else:
            n_rnti = [n_rnti]
            n_id = [n_id]
            self._multi_stream = False

        # check all entries for consistency
        for idx, (nr, ni) in enumerate(zip(n_rnti, n_id)):
            # allow floating inputs, but verify that it represent an int value
            assert nr%1==0, "n_rnti must be integer."
            assert 0 <= nr < 2**16, "n_rnti must be in [0, 65535]."
            n_rnti[idx] = int(nr)
            assert ni%1==0, "n_id must be integer."
            assert 0 <= ni < 2**10, "n_id must be in [0, 1023]."
            n_id[idx] = int(ni)

        self._c_init = []
        if channel_type=="PUSCH":
            # defined in 6.3.1.1 in 38.211
            for nr, ni in zip(n_rnti, n_id):
                self._c_init += [nr * 2**15 + ni]
        elif channel_type =="PDSCH":
            # defined in 7.3.1.1 in 38.211
            for nr, ni in zip(n_rnti, n_id):
                self._c_init += [nr * 2**15 + codeword_index * 2**14 + ni]

    #########################################
    # Public methods and properties
    #########################################

    @property
    def keep_state(self):
        """Required for descrambler, is always `True` for the TB5GScrambler."""
        return True

    #########################
    # Utility methods
    #########################

    def _generate_scrambling(self, input_shape):
        r"""Returns random sequence of `0`s and `1`s following
        [3GPPTS38211_scr]_ .

        One sequence of length ``input_shape[-1]`` is generated per entry of
        ``self._c_init``; multiple sequences (multi-stream mode) are stacked
        along axis `-2`. The result is expanded to the rank of the input.
        """
        rank = len(input_shape)
        seqs = []
        for c_init in self._c_init:
            s = generate_prng_seq(input_shape[-1], c_init)
            s = tf.constant(s, self.dtype) # enable flexible dtypes
            seqs.append(expand_to_rank(s, rank, axis=0))

        # in single-stream mode, self._c_init holds exactly one entry
        if len(seqs)==1:
            return seqs[0]
        return tf.concat(seqs, axis=-2)

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Initialize pseudo-random scrambling sequence.

        ``input_shape`` is either the shape of ``x`` or, if the layer is
        called with a tuple/list, a nested structure ``(shape_x,)`` or
        ``(shape_x, shape_binary)`` of which only the shape of ``x`` is used.
        """
        # input can be also a list/tuple of shapes; we are only interested in
        # the shape of x. A plain shape tuple such as (None, 100) has a
        # non-shape first entry and is therefore used as is.
        if isinstance(input_shape, (tuple, list)) and len(input_shape) > 0 \
                and isinstance(input_shape[0], (tf.TensorShape, tuple, list)):
            input_shape = input_shape[0]
        input_shape = tf.TensorShape(input_shape)
        self._input_shape = input_shape

        # in multi-stream mode, the axis=-2 must have dimension=len(c_init)
        if self._multi_stream:
            assert input_shape[-2]==len(self._c_init), \
                "Dimension of axis=-2 must be equal to len(n_rnti)."

        self._sequence = self._generate_scrambling(input_shape)

    def call(self, inputs):
        r"""This function returns the scrambled version of ``inputs``.

        ``inputs`` can be ``x``, ``[x]`` or ``[x, binary]``.

        Raises:
            AssertionError
                If ``binary`` is given as non-tensor but is not a bool.

            TypeError
                If ``binary`` is a non-boolean tensor or more than 2 inputs
                are given.
        """
        is_binary = self._binary # can be overwritten if explicitly provided

        if isinstance(inputs, (tuple, list)):
            if len(inputs)==1: # if user wants to call with call([x])
                x, = inputs
            elif len(inputs)==2:
                # allow that binary flag is explicitly provided (descrambler)
                x, is_binary = inputs
                # is_binary can be either a tensor or a Python bool
                if isinstance(is_binary, tf.Tensor):
                    if not is_binary.dtype.is_bool:
                        raise TypeError("binary must be bool.")
                else: # is boolean
                    assert isinstance(is_binary, bool), \
                        "binary must be bool."
            else:
                raise TypeError("inputs cannot have more than 2 entries.")
        else:
            x = inputs

        # (re-)generate the sequence only if the input shape changed.
        # tf.init_scope lifts the constant creation out of any function graph
        # that is currently being traced, so the cached sequence is never a
        # graph-local tensor and can safely be reused in later calls.
        # Note: `not a == b` avoids TensorShape.__ne__ which may raise for
        # shapes of unknown rank in some TF versions.
        if self._input_shape is None or not x.shape == self._input_shape:
            with tf.init_scope():
                self.build(x.shape)

        if is_binary:
            # flip the bits by subtraction and map -1 to 1 via abs(.) operator
            x_out = tf.abs(x - self._sequence)
        else:
            # map {0,1} -> {+1,-1} and flip the signs accordingly
            rand_seq_bipol = -2 * self._sequence + 1
            x_out = tf.multiply(x, rand_seq_bipol)

        return x_out

class Descrambler(Layer):
    r"""Descrambler(scrambler, binary=True, dtype=None, **kwargs)

    Descrambler for a given scrambler.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    scrambler: Scrambler, TB5GScrambler
        Associated :class:`~sionna.fec.scrambling.Scrambler` or
        :class:`~sionna.fec.scrambling.TB5GScrambler` instance which
        should be descrambled.

    binary: bool
        Defaults to True. Indicates whether bit-sequence should be flipped
        (i.e., binary operations are performed) or the signs should be
        flipped (i.e., soft-value/LLR domain-based).

    dtype: None or tf.DType
        Defaults to `None`. Defines the datatype for internal calculations
        and the output dtype. If no explicit dtype is provided the dtype
        from the associated scrambler is used.

    Input
    -----
    (x, seed):
        Either Tuple ``(x, seed)`` or ``x`` only (no tuple) if the internal
        seed should be used. For a
        :class:`~sionna.fec.scrambling.TB5GScrambler`, only ``x`` (or
        ``(x,)``) is accepted:

    x: tf.float
        1+D tensor of arbitrary shape.

    seed: int
        An integer defining the state of the random number
        generator. If explicitly given, the global internal seed is
        replaced by this seed. Can be used to realize random
        scrambler/descrambler pairs (call with same random seed).

    Output
    ------
    : tf.float
        1+D tensor of same shape as ``x``.

    Raises
    ------
    AssertionError
        If ``scrambler`` is not an instance of `Scrambler`.

    AssertionError
        If ``seed`` is provided to list of inputs but not an
        int.

    TypeError
        If `dtype` of ``x`` is not as expected.
    """
    def __init__(self,
                 scrambler,
                 binary=True,
                 dtype=None,
                 **kwargs):

        assert isinstance(scrambler, (Scrambler, TB5GScrambler)), \
            "scrambler must be an instance of Scrambler."
        self._scrambler = scrambler

        assert isinstance(binary, bool), "binary must be bool."
        self._binary = binary

        # if dtype is None, use same dtype as associated scrambler
        if dtype is None:
            dtype = self._scrambler.dtype

        super().__init__(dtype=dtype, **kwargs)

        if self._scrambler.keep_state is False:
            print("Warning: scrambler uses random sequences that cannot be " \
                  "access by descrambler. Please use keep_state=True and " \
                  "provide explicit random seed as input to call function.")

        if self._scrambler.dtype != self.dtype:
            print("Scrambler and descrambler are using different " \
                  "dtypes. This will cause an internal implicit cast.")

    #########################################
    # Public methods and properties
    #########################################

    @property
    def scrambler(self):
        """Associated scrambler instance."""
        return self._scrambler

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build the model and initialize variables."""
        pass

    def call(self, inputs):
        r"""Descrambling function.

        This function returns the descrambled version of ``inputs``.

        ``inputs`` can be either a list ``[x, seed]`` or single tensor ``x``.

        Args:
            inputs (List): ``[x, seed]``, where
            ``x`` (tf.float32): Tensor of arbitrary shape.
            ``seed`` (int): An integer defining the state of the random number
                generator. If explicitly given, the global internal seed is
                replaced by this seed. Can be used to realize random
                scrambler/descrambler pairs (must be called with same random
                seed). If not given, the seed of the associated scrambler is
                used.

        Returns:
            `tf.float32`: Tensor of same shape as the input.

        Raises:
            TypeError: If too many inputs are given or the scrambler type is
                unknown.
        """
        # Note: a new tuple is built in all branches; the caller's list is
        # never mutated (and tuple inputs are supported).
        if isinstance(self._scrambler, Scrambler):
            if isinstance(inputs, (tuple, list)):
                if len(inputs)>2:
                    raise TypeError("inputs cannot have more than 2 entries.")
                if len(inputs)==2: # seed explicitly given
                    x, seed = inputs
                else: # only [x] given, use seed from associated scrambler
                    x, = inputs
                    seed = self._scrambler.seed
            else: # seed not given
                x = inputs
                seed = self._scrambler.seed # use seed from associated scrambler
            inputs = (x, seed, self._binary)
        elif isinstance(self._scrambler, TB5GScrambler):
            if isinstance(inputs, (tuple, list)):
                if len(inputs)>1:
                    raise TypeError("inputs cannot have more than 1 entries.")
                x, = inputs
            else: # not list as input
                x = inputs
            inputs = (x, self._binary)
        else:
            raise TypeError("Unknown Scrambler type.")

        x_out = self._scrambler(inputs)

        # scrambler could potentially have different dtypes
        return tf.cast(x_out, super().dtype)