decoding.py (90204B)
1 # 2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 # SPDX-License-Identifier: Apache-2.0 4 # 5 """Layers for Polar decoding such as successive cancellation (SC), successive 6 cancellation list (SCL) and iterative belief propagation (BP) decoding.""" 7 8 import tensorflow as tf 9 import numpy as np 10 from numpy.core.numerictypes import issubdtype 11 import warnings 12 from tensorflow.keras.layers import Layer 13 from sionna.fec.crc import CRCDecoder, CRCEncoder 14 from sionna.fec.polar.encoding import Polar5GEncoder 15 import numbers 16 17 class PolarSCDecoder(Layer): 18 """PolarSCDecoder(frozen_pos, n, output_dtype=tf.float32, **kwargs) 19 20 Successive cancellation (SC) decoder [Arikan_Polar]_ for Polar codes and 21 Polar-like codes. 22 23 The class inherits from the Keras layer class and can be used as layer in a 24 Keras model. 25 26 Parameters 27 ---------- 28 frozen_pos: ndarray 29 Array of `int` defining the ``n-k`` indices of the frozen positions. 30 31 n: int 32 Defining the codeword length. 33 34 output_dtype: tf.DType 35 Defaults to tf.float32. Defines the output datatype of the layer 36 (internal precision remains tf.float32). 37 38 Input 39 ----- 40 inputs: [...,n], tf.float32 41 2+D tensor containing the channel LLR values (as logits). 42 43 Output 44 ------ 45 : [...,k], tf.float32 46 2+D tensor containing hard-decided estimations of all ``k`` 47 information bits. 48 49 Raises 50 ------ 51 AssertionError 52 If ``n`` is not `int`. 53 54 AssertionError 55 If ``n`` is not a power of 2. 56 57 AssertionError 58 If the number of elements in ``frozen_pos`` is greater than ``n``. 59 60 AssertionError 61 If ``frozen_pos`` does not consists of `int`. 62 63 ValueError 64 If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}. 65 66 Note 67 ---- 68 This layer implements the SC decoder as described in 69 [Arikan_Polar]_. 
However, the implementation follows the `recursive 70 tree` [Gross_Fast_SCL]_ terminology and combines nodes for increased 71 throughputs without changing the outcome of the algorithm. 72 73 As commonly done, we assume frozen bits are set to `0`. Please note 74 that - although its practical relevance is only little - setting frozen 75 bits to `1` may result in `affine` codes instead of linear code as the 76 `all-zero` codeword is not necessarily part of the code any more. 77 78 """ 79 80 def __init__(self, frozen_pos, n, output_dtype=tf.float32, **kwargs): 81 82 if output_dtype not in (tf.float16, tf.float32, tf.float64): 83 raise ValueError( 84 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 85 86 if output_dtype is not tf.float32: 87 print('Note: decoder uses tf.float32 for internal calculations.') 88 89 super().__init__(dtype=output_dtype, **kwargs) 90 self._output_dtype = output_dtype 91 92 # assert error if r>1 or k, n are negativ 93 assert isinstance(n, numbers.Number), "n must be a number." 94 n = int(n) # n can be float (e.g. as result of n=k*r) 95 96 assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int." 97 assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \ 98 "be greater than n." 99 assert np.log2(n)==int(np.log2(n)), "n must be a power of 2." 100 101 # store internal attributes 102 self._n = n 103 self._frozen_pos = frozen_pos 104 self._k = self._n - len(self._frozen_pos) 105 self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos) 106 assert self._k==len(self._info_pos), "Internal error: invalid " \ 107 "info_pos generated." 108 self._llr_max = 30. 
# internal max LLR value (uncritical for SC dec) 109 # and create a frozen bit vector for simpler encoding 110 self._frozen_ind = np.zeros(self._n) 111 self._frozen_ind[self._frozen_pos] = 1 112 113 # enable graph pruning 114 self._use_fast_sc = False 115 116 ######################################### 117 # Public methods and properties 118 ######################################### 119 120 @property 121 def n(self): 122 """Codeword length.""" 123 return self._n 124 125 @property 126 def k(self): 127 """Number of information bits.""" 128 return self._k 129 130 @property 131 def frozen_pos(self): 132 """Frozen positions for Polar decoding.""" 133 return self._frozen_pos 134 135 @property 136 def info_pos(self): 137 """Information bit positions for Polar encoding.""" 138 return self._info_pos 139 140 @property 141 def llr_max(self): 142 """Maximum LLR value for internal calculations.""" 143 return self._llr_max 144 145 @property 146 def output_dtype(self): 147 """Output dtype of decoder.""" 148 return self._output_dtype 149 150 ######################### 151 # Utility methods 152 ######################### 153 154 def _cn_op_tf(self, x, y): 155 """Check-node update (boxplus) for LLR inputs. 156 157 Operations are performed element-wise. 158 159 See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations. 160 """ 161 x_in = tf.clip_by_value(x, 162 clip_value_min=-self._llr_max, 163 clip_value_max=self._llr_max) 164 y_in = tf.clip_by_value(y, 165 clip_value_min=-self._llr_max, 166 clip_value_max=self._llr_max) 167 168 # avoid division for numerical stability 169 llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in)) 170 llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in)) 171 172 return llr_out 173 174 def _vn_op_tf(self, x, y, u_hat): 175 """VN update for LLR inputs.""" 176 return tf.multiply((1-2*u_hat), x) + y 177 178 def _polar_decode_sc_tf(self, llr_ch, frozen_ind): 179 """Recursive SC decoding function. 
180 181 Recursively branch decoding tree and split into decoding of `upper` 182 and `lower` path until reaching a leaf node. 183 184 The function returns the u_hat decisions at stage `0` and the bit 185 decisions of the intermediate stage `s` (i.e., the re-encoded version of 186 `u_hat` until the current stage `s`). 187 188 Note: 189 This decoder parallelizes over the batch-dimension, i.e., the tree 190 is processed for all samples in the batch in parallel. This yields a 191 higher throughput, but does not improve the latency. 192 """ 193 194 # calculate current codeword length 195 n = len(frozen_ind) 196 197 # branch if leaf is not reached yet 198 if n>1: 199 if self._use_fast_sc: 200 if np.sum(frozen_ind)==n: 201 #print("rate-0 detected! Length: ", n) 202 u_hat = tf.zeros_like(llr_ch) 203 return u_hat, u_hat 204 205 llr_ch1 = llr_ch[...,0:int(n/2)] 206 llr_ch2 = llr_ch[...,int(n/2):] 207 frozen_ind1 = frozen_ind[0:int(n/2)] 208 frozen_ind2 = frozen_ind[int(n/2):] 209 210 # upper path 211 x_llr1_in = self._cn_op_tf(llr_ch1, llr_ch2) 212 213 # and call the decoding function (with upper half) 214 u_hat1, u_hat1_up = self._polar_decode_sc_tf(x_llr1_in, frozen_ind1) 215 216 # lower path 217 x_llr2_in = self._vn_op_tf(llr_ch1, llr_ch2, u_hat1_up) 218 # and call the decoding function again (with lower half) 219 u_hat2, u_hat2_up = self._polar_decode_sc_tf(x_llr2_in, frozen_ind2) 220 221 # combine u_hat from both branches 222 u_hat = tf.concat([u_hat1, u_hat2], -1) 223 224 # calculate re-encoded version of u_hat at current stage 225 # u_hat1_up = tf.math.mod(u_hat1_up + u_hat2_up, 2) 226 # combine u_hat via bitwise_xor (more efficient than mod2) 227 u_hat1_up_int = tf.cast(u_hat1_up, tf.int8) 228 u_hat2_up_int = tf.cast(u_hat2_up, tf.int8) 229 u_hat1_up_int = tf.bitwise.bitwise_xor(u_hat1_up_int, 230 u_hat2_up_int) 231 u_hat1_up = tf.cast(u_hat1_up_int , tf.float32) 232 u_hat_up = tf.concat([u_hat1_up, u_hat2_up], -1) 233 234 else: # if leaf is reached perform basic 
decoding op (=decision) 235 236 if frozen_ind==1: # position is frozen 237 u_hat = tf.expand_dims(tf.zeros_like(llr_ch[:,0]), axis=-1) 238 u_hat_up = u_hat 239 else: # otherwise hard decide 240 u_hat = 0.5 * (1. - tf.sign(llr_ch)) 241 #remove "exact 0 llrs" leading to u_hat=0.5 242 u_hat = tf.where(tf.equal(u_hat, 0.5), 243 tf.ones_like(u_hat), 244 u_hat) 245 u_hat_up = u_hat 246 return u_hat, u_hat_up 247 248 ######################### 249 # Keras layer functions 250 ######################### 251 252 def build(self, input_shape): 253 """Check if shape of input is invalid.""" 254 assert (input_shape[-1]==self._n), "Invalid input shape." 255 assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.' 256 257 def call(self, inputs): 258 """Successive cancellation (SC) decoding function. 259 260 Performs successive cancellation decoding and returns the estimated 261 information bits. 262 263 Args: 264 inputs (tf.float32): Tensor of shape `[...,n]` containing the 265 channel LLR values (as logits). 266 267 Returns: 268 `tf.float32`: Tensor of shape `[...,k]` containing 269 hard-decided estimations of all ``k`` information bits. 270 271 Raises: 272 ValueError: If ``inputs`` is not of shape `[..., n]` 273 or `dtype` is not `tf.float32`. 274 275 InvalidArgumentError: When rank(``inputs``)<2. 276 277 Note: 278 This function recursively unrolls the SC decoding tree, thus, 279 for larger values of ``n`` building the decoding graph can become 280 time consuming. 
281 """ 282 283 tf.debugging.assert_type(inputs, self.dtype, 'Invalid input dtype.') 284 # internal calculations still in tf.float32 285 inputs = tf.cast(inputs, tf.float32) 286 287 # last dim must be of length n 288 tf.debugging.assert_equal(tf.shape(inputs)[-1], 289 self._n, 290 "Last input dimension must be of length n.") 291 292 # Reshape inputs to [-1, n] 293 tf.debugging.assert_greater(tf.rank(inputs), 1) 294 input_shape = inputs.shape 295 new_shape = [-1, self._n] 296 llr_ch = tf.reshape(inputs, new_shape) 297 298 llr_ch = -1. * llr_ch # logits are converted into "true" llrs 299 300 # and decode 301 u_hat_n, _ = self._polar_decode_sc_tf(llr_ch, self._frozen_ind) 302 303 # and recover the k information bit positions 304 u_hat = tf.gather(u_hat_n, self._info_pos, axis=1) 305 306 # and reconstruct input shape 307 output_shape = input_shape.as_list() 308 output_shape[-1] = self.k 309 output_shape[0] = -1 # first dim can be dynamic (None) 310 u_hat_reshape = tf.reshape(u_hat, output_shape) 311 return tf.cast(u_hat_reshape, self._output_dtype) 312 313 class PolarSCLDecoder(Layer): 314 # pylint: disable=line-too-long 315 """PolarSCLDecoder(frozen_pos, n, list_size=8, crc_degree=None, use_hybrid_sc=False, use_fast_scl=True, cpu_only=False, use_scatter=False, ind_iil_inv=None, return_crc_status=False, output_dtype=tf.float32, **kwargs) 316 317 Successive cancellation list (SCL) decoder [Tal_SCL]_ for Polar codes 318 and Polar-like codes. 319 320 The class inherits from the Keras layer class and can be used as layer in a 321 Keras model. 322 323 Parameters 324 ---------- 325 frozen_pos: ndarray 326 Array of `int` defining the ``n-k`` indices of the frozen positions. 327 328 n: int 329 Defining the codeword length. 330 331 list_size: int 332 Defaults to 8. Defines the list size of the decoder. 333 334 crc_degree: str 335 Defining the CRC polynomial to be used. Can be any value from 336 `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`. 
337 338 use_hybrid_sc: bool 339 Defaults to False. If True, SC decoding is applied and only the 340 codewords with invalid CRC are decoded with SCL. This option 341 requires an outer CRC specified via ``crc_degree``. 342 Remark: hybrid_sc does not support XLA optimization, i.e., 343 `@tf.function(jit_compile=True)`. 344 345 use_fast_scl: bool 346 Defaults to True. If True, Tree pruning is used to 347 reduce the decoding complexity. The output is equivalent to the 348 non-pruned version (besides numerical differences). 349 350 cpu_only: bool 351 Defaults to False. If True, `tf.py_function` embedding 352 is used and the decoder runs on the CPU. This option is usually 353 slower, but also more memory efficient and, in particular, 354 recommended for larger blocklengths. Remark: cpu_only does not 355 support XLA optimization `@tf.function(jit_compile=True)`. 356 357 use_scatter: bool 358 Defaults to False. If True, `tf.tensor_scatter_update` is used for 359 tensor updates. This option is usually slower, but more memory 360 efficient. 361 362 ind_iil_inv : None or [k+k_crc], int or tf.int 363 Defaults to None. If not `None`, the sequence is used as inverse 364 input bit interleaver before evaluating the CRC. 365 Remark: this only effects the CRC evaluation but the output 366 sequence is not permuted. 367 368 return_crc_status: bool 369 Defaults to False. If True, the decoder additionally returns the 370 CRC status indicating if a codeword was (most likely) correctly 371 recovered. This is only available if ``crc_degree`` is not None. 372 373 output_dtype: tf.DType 374 Defaults to tf.float32. Defines the output datatype of the layer 375 (internal precision remains tf.float32). 376 377 Input 378 ----- 379 inputs: [...,n], tf.float32 380 2+D tensor containing the channel LLR values (as logits). 381 382 Output 383 ------ 384 b_hat : [...,k], tf.float32 385 2+D tensor containing hard-decided estimations of all `k` 386 information bits. 
387 388 crc_status : [...], tf.bool 389 CRC status indicating if a codeword was (most likely) correctly 390 recovered. This is only returned if ``return_crc_status`` is True. 391 Note that false positives are possible. 392 393 Raises: 394 AssertionError 395 If ``n`` is not `int`. 396 397 AssertionError 398 If ``n`` is not a power of 2. 399 400 AssertionError 401 If the number of elements in ``frozen_pos`` is greater than ``n``. 402 403 AssertionError 404 If ``frozen_pos`` does not consists of `int`. 405 406 AssertionError 407 If ``list_size`` is not `int`. 408 409 AssertionError 410 If ``cpu_only`` is not `bool`. 411 412 AssertionError 413 If ``use_scatter`` is not `bool`. 414 415 AssertionError 416 If ``use_fast_scl`` is not `bool`. 417 418 AssertionError 419 If ``use_hybrid_sc`` is not `bool`. 420 421 AssertionError 422 If ``list_size`` is not a power of 2. 423 424 ValueError 425 If ``output_dtype`` is not {tf.float16, tf.float32, tf. 426 float64}. 427 428 ValueError 429 If ``inputs`` is not of shape `[..., n]` or `dtype` is not 430 correct. 431 432 InvalidArgumentError 433 When rank(``inputs``)<2. 434 435 Note 436 ---- 437 This layer implements the successive cancellation list (SCL) decoder 438 as described in [Tal_SCL]_ but uses LLR-based message updates 439 [Stimming_LLR]_. The implementation follows the notation from 440 [Gross_Fast_SCL]_, [Hashemi_SSCL]_. If option `use_fast_scl` is active 441 tree pruning is used and tree nodes are combined if possible (see 442 [Hashemi_SSCL]_ for details). 443 444 Implementing SCL decoding as TensorFlow graph is a difficult task that 445 requires several design tradeoffs to match the TF constraints while 446 maintaining a reasonable throughput. Thus, the decoder minimizes 447 the `control flow` as much as possible, leading to a strong memory 448 occupation (e.g., due to full path duplication after each decision). 
449 For longer code lengths, the complexity of the decoding graph becomes 450 large and we recommend to use the `CPU_only` option that uses an 451 embedded Numpy decoder. Further, this function recursively unrolls the 452 SCL decoding tree, thus, for larger values of ``n`` building the 453 decoding graph can become time consuming. Please consider the 454 ``cpu_only`` option if building the graph takes to long. 455 456 A hybrid SC/SCL decoder as proposed in [Cammerer_Hybrid_SCL]_ (using SC 457 instead of BP) can be activated with option ``use_hybrid_sc`` iff an 458 outer CRC is available. Please note that the results are not exactly 459 SCL performance caused by the false positive rate of the CRC. 460 461 As commonly done, we assume frozen bits are set to `0`. Please note 462 that - although its practical relevance is only little - setting frozen 463 bits to `1` may result in `affine` codes instead of linear code as the 464 `all-zero` codeword is not necessarily part of the code any more. 465 """ 466 467 def __init__(self, 468 frozen_pos, 469 n, 470 list_size=8, 471 crc_degree=None, 472 use_hybrid_sc=False, 473 use_fast_scl=True, 474 cpu_only=False, 475 use_scatter=False, 476 ind_iil_inv=None, 477 return_crc_status=False, 478 output_dtype=tf.float32, 479 **kwargs): 480 481 if output_dtype not in (tf.float16, tf.float32, tf.float64): 482 raise ValueError( 483 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 484 485 if output_dtype is not tf.float32: 486 print('Note: decoder uses tf.float32 for internal calculations.') 487 488 super().__init__(dtype=output_dtype, **kwargs) 489 self._output_dtype = output_dtype 490 491 # assert error if r>1 or k, n are negative 492 assert isinstance(n, numbers.Number), "n must be a number." 493 n = int(n) # n can be float (e.g. as result of n=k*r) 494 assert isinstance(list_size, int), "list_size must be integer." 495 assert isinstance(cpu_only, bool), "cpu_only must be bool." 
496 assert isinstance(use_scatter, bool), "use_scatter must be bool." 497 assert isinstance(use_fast_scl, bool), "use_fast_scl must be bool." 498 assert isinstance(use_hybrid_sc, bool), "use_hybrid_sc must be bool." 499 assert isinstance(return_crc_status, bool), \ 500 "return_crc_status must be bool." 501 502 assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int." 503 assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \ 504 "be greater than n." 505 assert np.log2(n)==int(np.log2(n)), "n must be a power of 2." 506 assert np.log2(list_size)==int(np.log2(list_size)), \ 507 "list_size must be a power of 2." 508 509 # CPU mode is recommended for larger values of n 510 if n>128 and cpu_only is False and use_hybrid_sc is False: 511 warnings.warn("Required resource allocation is large " \ 512 "for the selected blocklength. Consider option `cpu_only=True`.") 513 514 # CPU mode is recommended for larger values of L 515 if list_size>32 and cpu_only is False and use_hybrid_sc is False: 516 warnings.warn("Resource allocation is high for the " \ 517 "selected list_size. Consider option `cpu_only=True`.") 518 519 # internal decoder parameters 520 self._use_fast_scl = use_fast_scl # optimize rate-0 and rep nodes 521 self._use_scatter = use_scatter # slower but more memory friendly 522 self._cpu_only = cpu_only # run numpy decoder 523 self._use_hybrid_sc = use_hybrid_sc 524 525 # store internal attributes 526 self._n = n 527 self._frozen_pos = frozen_pos 528 self._k = self._n - len(self._frozen_pos) 529 self._list_size = list_size 530 self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos) 531 self._llr_max = 30. # internal max LLR value (not very critical for SC) 532 assert self._k==len(self._info_pos), "Internal error: invalid " \ 533 "info_pos generated." 
534 # create a frozen bit vector 535 self._frozen_ind = np.zeros(self._n) 536 self._frozen_ind[self._frozen_pos] = 1 537 self._cw_ind = np.arange(self._n) 538 self._n_stages = int(np.log2(self._n)) # number of decoding stages 539 540 # init CRC check (if needed) 541 if crc_degree is not None: 542 self._use_crc = True 543 self._crc_decoder = CRCDecoder(CRCEncoder(crc_degree)) 544 self._k_crc = self._crc_decoder.encoder.crc_length 545 else: 546 self._use_crc = False 547 self._k_crc = 0 548 assert self._k>=self._k_crc, "Value of k is too small for \ 549 given CRC_degree." 550 551 552 if (crc_degree is None) and return_crc_status: 553 self._return_crc_status = False 554 raise ValueError("Returning CRC status requires given crc_degree.") 555 else: 556 self._return_crc_status = return_crc_status 557 558 559 # store the inverse interleaver patter 560 if ind_iil_inv is not None: 561 assert (ind_iil_inv.shape[0]==self._k), \ 562 "ind_int must be of length k+k_crc." 563 self._ind_iil_inv = ind_iil_inv 564 self._iil = True 565 else: 566 self._iil = False 567 568 # use SC decoder first and use numpy-based SCL as "afterburner" 569 if self._use_hybrid_sc: 570 self._decoder_sc = PolarSCDecoder(frozen_pos, n) 571 # Note: CRC required to detect SC success 572 if not self._use_crc: 573 raise ValueError("Hybrid SC requires outer CRC.") 574 575 ######################################### 576 # Public methods and properties 577 ######################################### 578 579 @property 580 def n(self): 581 """Codeword length.""" 582 return self._n 583 584 @property 585 def k(self): 586 """Number of information bits.""" 587 return self._k 588 589 @property 590 def k_crc(self): 591 """Number of CRC bits.""" 592 return self._k_crc 593 594 @property 595 def frozen_pos(self): 596 """Frozen positions for Polar decoding.""" 597 return self._frozen_pos 598 599 @property 600 def info_pos(self): 601 """Information bit positions for Polar encoding.""" 602 return self._info_pos 603 604 @property 
605 def llr_max(self): 606 """Maximum LLR value for internal calculations.""" 607 return self._llr_max 608 609 @property 610 def list_size(self): 611 """List size for SCL decoding.""" 612 return self._list_size 613 614 @property 615 def output_dtype(self): 616 """Output dtype of decoder.""" 617 return self._output_dtype 618 619 ##################################### 620 # Helper functions for the TF decoder 621 ##################################### 622 623 def _update_rate0_code(self, msg_pm, msg_uhat, msg_llr, cw_ind): 624 """Update rate-0 sub-code (i.e., all frozen) at pos ``cw_ind``. 625 626 See eq. (26) in [Hashemi_SSCL]_. 627 628 Remark: bits are not explicitly set to `0` as ``msg_uhat`` is 629 initialized with `0` already. 630 """ 631 n = len(cw_ind) 632 stage_ind = int(np.log2(n)) 633 634 llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2) 635 llr_in = tf.clip_by_value(llr, 636 clip_value_min=-self._llr_max, 637 clip_value_max=self._llr_max) 638 639 # update path metric for complete sub-block of length n 640 pm_val = tf.math.softplus(-1.*llr_in) 641 msg_pm += tf.reduce_sum(pm_val, axis=-1) 642 643 return msg_pm, msg_uhat, msg_llr 644 645 def _update_rep_code(self, msg_pm, msg_uhat, msg_llr, cw_ind): 646 """Update rep. code (i.e., only rightmost bit is non-frozen) 647 sub-code at position ``ind_u``. 648 649 See Eq. (31) in [Hashemi_SSCL]_. 650 651 Remark: bits are not explicitly set to `0` as ``msg_uhat`` is 652 initialized with `0` already. 
653 """ 654 n = len(cw_ind) 655 stage_ind = int(np.log2(n)) 656 657 # update PM 658 llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2) 659 llr_in = tf.clip_by_value(llr, 660 clip_value_min=-self._llr_max, 661 clip_value_max=self._llr_max) 662 663 # upper branch has negative llr values (bit is 1) 664 llr_low = llr_in[:, :self._list_size, :] 665 llr_up = - llr_in[:, self._list_size:, :] 666 llr_pm = tf.concat([llr_low, llr_up], 1) 667 pm_val = tf.math.softplus(-1.*llr_pm) 668 msg_pm += tf.reduce_sum(pm_val, axis=-1) 669 670 msg_uhat1 = msg_uhat[:, :self._list_size, :, :] 671 msg_uhat21 = tf.expand_dims( 672 msg_uhat[:, self._list_size:, stage_ind, :cw_ind[0]], 673 axis=2) 674 675 msg_uhat22= tf.expand_dims( 676 msg_uhat[:, self._list_size:, stage_ind, cw_ind[-1]+1:], 677 axis=2) 678 # ones to insert 679 msg_ones = tf.ones([tf.shape(msg_uhat)[0], self._list_size, 1, n], 680 tf.float32) 681 682 msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3) 683 msg_uhat24_1 = msg_uhat[:, self._list_size:, :stage_ind, :] 684 msg_uhat24_2 = msg_uhat[:, self._list_size:, stage_ind+1:, :] 685 686 msg_uhat2 = tf.concat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], 2) 687 msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1) 688 689 # branch last bit and update pm at pos cw_ind[-1] 690 msg_uhat = self._update_single_bit([cw_ind[-1]], msg_uhat) 691 msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm, 692 msg_uhat, 693 msg_llr) 694 msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat, 695 msg_llr, 696 msg_pm) 697 return msg_pm, msg_uhat, msg_llr 698 699 def _update_single_bit(self, ind_u, msg_uhat): 700 """Update single bit at position ``ind_u`` for all decoders. 701 702 Remark: bits are not explicitly set to `0` as ``msg_uhat`` is 703 initialized with `0` already. 704 705 Remark: Two versions are implemented (throughput vs. graph complexity): 706 1.) use tensor_scatter_nd_update 707 2.) 
explicitly split graph and concatenate again 708 """ 709 # position is non-frozen 710 if self._frozen_ind[ind_u[0]]==0: 711 712 # msg_uhat[:, ind_up, 0, ind_u] = 1 713 if self._use_scatter: 714 ind_dec = np.arange(self._list_size, 2*self._list_size, 1) 715 ind_stage = np.array([0]) 716 717 # transpose such that batch dim can be broadcasted 718 msg_uhat_t = tf.transpose(msg_uhat, [1, 3, 2, 0]) 719 720 # generate index grid 721 ind_u = tf.cast(ind_u, tf.int64) 722 grid = tf.meshgrid(ind_dec, ind_u, ind_stage) 723 ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 3]) 724 725 updates = tf.ones([ind.shape[0], tf.shape(msg_uhat)[0]]) 726 msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t, 727 ind, 728 updates) 729 # and restore original order 730 msg_uhat = tf.transpose(msg_uhat_s, [3, 0, 2, 1]) 731 else: 732 # alternative solution with split/concatenation of graph 733 msg_uhat1 = msg_uhat[:, :self._list_size, :, :] 734 msg_uhat21 = tf.expand_dims( 735 msg_uhat[:, self._list_size:, 0, :ind_u[0]], 736 axis=2) 737 738 msg_uhat22= tf.expand_dims( 739 msg_uhat[:, self._list_size:, 0, ind_u[0]+1:], 740 axis=2) 741 # ones to insert 742 msg_ones = tf.ones_like(tf.reshape( 743 msg_uhat[:, self._list_size:, 0, ind_u[0]], 744 [-1, self._list_size, 1, 1])) 745 746 msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3) 747 msg_uhat24 = msg_uhat[:, self._list_size:, 1:, :] 748 749 msg_uhat2 = tf.concat([msg_uhat23, msg_uhat24], 2) 750 msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1) 751 752 return msg_uhat 753 754 def _update_pm(self, ind_u, msg_uhat, msg_llr, msg_pm): 755 """Update path metric of all decoders after updating bit_pos ``ind_u``. 756 757 We implement (10) from [Stimming_LLR]_. 
758 """ 759 u_hat = msg_uhat[:, :, 0, ind_u[0]] 760 llr = msg_llr[:, :, 0, ind_u[0]] 761 762 llr_in = tf.clip_by_value(llr, 763 clip_value_min=-self._llr_max, 764 clip_value_max=self._llr_max) 765 766 # Numerically more stable implementation of log(1 + exp(-x)) 767 msg_pm += tf.math.softplus(-tf.multiply((1 - 2*u_hat), llr_in)) 768 return msg_pm 769 770 def _sort_decoders(self, msg_pm, msg_uhat, msg_llr): 771 """Sort decoders according to their path metric.""" 772 773 ind = tf.argsort(msg_pm, axis=-1) 774 775 msg_pm = tf.gather(msg_pm, ind, batch_dims=1, axis=None) 776 msg_uhat = tf.gather(msg_uhat, ind, batch_dims=1, axis=None) 777 msg_llr = tf.gather(msg_llr, ind, batch_dims=1, axis=None) 778 779 return msg_pm, msg_uhat, msg_llr 780 781 def _cn_op(self, x, y): 782 """Check-node update (boxplus) for LLR inputs. 783 784 Operations are performed element-wise. 785 786 See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations. 787 """ 788 x_in = tf.clip_by_value(x, 789 clip_value_min=-self._llr_max, 790 clip_value_max=self._llr_max) 791 y_in = tf.clip_by_value(y, 792 clip_value_min=-self._llr_max, 793 clip_value_max=self._llr_max) 794 795 # Avoid division for numerical stability 796 # Implements log(1+e^(x+y)) 797 llr_out = tf.math.softplus((x_in + y_in)) 798 # Implements log(e^x+e^y) 799 llr_out -= tf.math.reduce_logsumexp(tf.stack([x_in, y_in], axis=-1), 800 axis=-1) 801 802 return llr_out 803 804 def _vn_op(self, x, y, u_hat): 805 """Variable node update for LLR inputs. 806 807 Operations are performed element-wise. 808 809 See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations. 810 """ 811 return tf.multiply((1 - 2*u_hat), x) + y 812 813 def _duplicate_paths(self, msg_uhat, msg_llr, msg_pm): 814 """Duplicate paths by copying the upper branch into the lower one. 
815 """ 816 msg_uhat = tf.tile(msg_uhat[:, :self._list_size, :, :], [1, 2, 1, 1]) 817 msg_llr = tf.tile(msg_llr[:, :self._list_size, :, :], [1, 2, 1, 1]) 818 msg_pm = tf.tile(msg_pm[:, :self._list_size], [1, 2]) 819 820 return msg_uhat, msg_llr, msg_pm 821 822 def _update_left_branch(self, msg_llr, stage_ind, cw_ind_left,cw_ind_right): 823 """Update messages of left branch. 824 825 Remark: Two versions are implemented (throughput vs. graph complexity): 826 1.) use tensor_scatter_nd_update 827 2.) explicitly split graph and concatenate again 828 """ 829 830 llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :], 831 cw_ind_left, 832 axis=2) 833 llr_right_in = tf.gather(msg_llr[:, :, stage_ind, :], 834 cw_ind_right, 835 axis=2) 836 837 llr_left_out = self._cn_op(llr_left_in, llr_right_in) 838 839 if self._use_scatter: 840 # self.msg_llr[:, :, stage_ind-1, cw_ind_left] = llr_left_out 841 842 # transpose such that batch-dim can be broadcasted 843 msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0]) 844 llr_left_out_s = tf.transpose(llr_left_out, [2, 1, 0]) 845 846 # generate index grid 847 stage_ind = tf.cast(stage_ind, tf.int64) 848 cw_ind_left = tf.cast(cw_ind_left, tf.int64) 849 grid = tf.meshgrid(stage_ind-1, cw_ind_left) 850 ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2]) 851 852 # update values 853 msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t, 854 ind, 855 llr_left_out_s) 856 857 # and restore original order 858 msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1]) 859 else: 860 # alternative solution with split/concatenation of graph 861 # llr_left = msg_llr[:, :, stage_ind, cw_ind_left] 862 llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :], 863 np.arange(0, cw_ind_left[0]), 864 axis=2) 865 866 llr_right = tf.gather(msg_llr[:, :, stage_ind-1, :], 867 cw_ind_right, 868 axis=2) 869 llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :], 870 np.arange(cw_ind_right[-1] +1, self._n), 871 axis=2) 872 873 llr_s = tf.concat([llr_left0, 874 llr_left_out, 875 llr_right, 876 
llr_right1], 2) 877 878 llr_s = tf.expand_dims(llr_s, axis=2) 879 880 msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :] 881 msg_llr2 = msg_llr[:, :, stage_ind:, :] 882 msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2) 883 884 return msg_llr 885 886 def _update_right_branch(self, msg_llr, msg_uhat, stage_ind, cw_ind_left, 887 cw_ind_right): 888 """Update messages for right branch. 889 890 Remark: Two versions are implemented (throughput vs. graph complexity): 891 1.) use tensor_scatter_nd_update 892 2.) explicitly split graph and concatenate again 893 """ 894 u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :], 895 cw_ind_left, 896 axis=2) 897 898 llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :], 899 cw_ind_left, 900 axis=2) 901 902 llr_right = tf.gather(msg_llr[:, :, stage_ind, :], 903 cw_ind_right, 904 axis=2) 905 906 llr_right_out = self._vn_op(llr_left_in, llr_right, u_hat_left_up) 907 908 if self._use_scatter: 909 # transpose such that batch dim can be broadcasted 910 msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0]) 911 llr_right_out_s = tf.transpose(llr_right_out, [2, 1, 0]) 912 913 # generate index grid 914 stage_ind = tf.cast(stage_ind, tf.int64) 915 cw_ind_left = tf.cast(cw_ind_right, tf.int64) 916 grid = tf.meshgrid(stage_ind-1, cw_ind_right) 917 ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2]) 918 919 msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t, 920 ind, 921 llr_right_out_s) 922 923 # and restore original order 924 msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1]) 925 else: 926 # alternative solution with split/concatenation of graph 927 # llr_left = msg_llr[:, :, stage_ind, cw_ind_left] 928 llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :], 929 np.arange(0, cw_ind_left[0]), 930 axis=2) 931 llr_left = tf.gather(msg_llr[:, :, stage_ind-1, :], 932 cw_ind_left, 933 axis=2) 934 llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :], 935 np.arange(cw_ind_right[-1]+1, self._n), 936 axis=2) 937 938 llr_s = tf.concat([llr_left0, llr_left, 
                                 llr_right_out,llr_right1],2)
            llr_s = tf.expand_dims(llr_s, axis=2)

            # splice the updated LLR slice back into the stage dimension
            # (stages below stage_ind-1 and above stage_ind are untouched)
            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]

            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr

    def _update_branch_u(self, msg_uhat, stage_ind, cw_ind_left, cw_ind_right):
        """Update ``u_hat`` messages after executing both branches.

        Combines the hard decisions of the left and right sub-trees from
        stage ``stage_ind-1`` and writes the result into stage ``stage_ind``
        of ``msg_uhat``.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # hard decisions of both sub-trees from the next lower stage
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)

        u_hat_right_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                   cw_ind_right,
                                   axis=2)

        # combine u_hat via bitwise_xor (more efficient than mod2)
        u_hat_left_up_int = tf.cast(u_hat_left_up, tf.int32)
        u_hat_right_up_int = tf.cast(u_hat_right_up, tf.int32)
        u_hat_left = tf.bitwise.bitwise_xor(u_hat_left_up_int,
                                            u_hat_right_up_int)
        u_hat_left = tf.cast(u_hat_left, tf.float32)

        if self._use_scatter:
            cw_ind = np.concatenate([cw_ind_left, cw_ind_right])

            u_hat = tf.concat([u_hat_left, u_hat_right_up], -1)

            # self.msg_llr[:, stage_ind-1, cw_ind_left] = llr_left_out

            # transpose such that batch dim can be broadcasted
            msg_uhat_t = tf.transpose(msg_uhat, [2, 3, 1, 0])
            u_hat_s = tf.transpose(u_hat, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind = tf.cast(cw_ind, tf.int64)
            grid = tf.meshgrid(stage_ind, cw_ind)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                     ind,
                                                     u_hat_s)

            # and restore original order
            msg_uhat = tf.transpose(msg_uhat_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            u_hat_left_0 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                     np.arange(0, cw_ind_left[0]),
                                     axis=2)
            u_hat_right_1 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                      np.arange(cw_ind_right[-1]+1, self._n),
                                      axis=2)

            u_hat = tf.concat([u_hat_left_0,
                               u_hat_left,
                               u_hat_right_up,
                               u_hat_right_1], 2)

            # provide u_hat for next higher stage
            msg_uhat1 = msg_uhat[:, :, 0:stage_ind, :]
            msg_uhat2 = msg_uhat[:, :, stage_ind+1:, :]
            u_hat = tf.expand_dims(u_hat, axis=2)

            msg_uhat = tf.concat([msg_uhat1, u_hat, msg_uhat2], 2)

        return msg_uhat

    def _polar_decode_scl(self, cw_ind, msg_uhat, msg_llr, msg_pm):
        """Recursive decoding function for SCL decoding.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and `right` update paths until
        reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.
        """
        # current sub-code length and stage index (= tree depth)
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n>1:
            # prune tree if rate-0 subcode is detected
            if self._use_fast_scl:
                # rate-0: all positions of this sub-code are frozen
                if np.sum(self._frozen_ind[cw_ind])==n:
                    msg_pm, msg_uhat, msg_llr = self._update_rate0_code(msg_pm,
                                                                    msg_uhat,
                                                                    msg_llr,
                                                                    cw_ind)
                    return msg_uhat, msg_llr, msg_pm

                # repetition code: only the last position is non-frozen
                if (self._frozen_ind[cw_ind[-1]]==0 and
                    np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
                    msg_pm, msg_uhat, msg_llr, = self._update_rep_code(msg_pm,
                                                                     msg_uhat,
                                                                     msg_llr,
                                                                     cw_ind)
                    return msg_uhat, msg_llr, msg_pm

            # split index into left and right part
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            msg_llr = self._update_left_branch(msg_llr,
                                               stage_ind,
                                               cw_ind_left,
                                               cw_ind_right)

            # call sub-graph decoder of left branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_left,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)

            # ----- right branch -----
            msg_llr = self._update_right_branch(msg_llr,
                                                msg_uhat,
                                                stage_ind,
                                                cw_ind_left,
                                                cw_ind_right)

            # call sub-graph decoder of right branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_right,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)
            # update uhat at current stage
            msg_uhat = self._update_branch_u(msg_uhat,
                                             stage_ind,
                                             cw_ind_left,
                                             cw_ind_right)

        # if leaf is reached perform basic decoding op (=decision)
        else:
            # update bit value at current position
            msg_uhat = self._update_single_bit(cw_ind, msg_uhat)

            # update PM
            msg_pm = self._update_pm(cw_ind, msg_uhat, msg_llr, msg_pm)

            if self._frozen_ind[cw_ind]==0: # position is non-frozen
                # sort list
                msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                                msg_uhat,
                                                                msg_llr)

                # duplicate l best decoders to pos l:2*l (kill other decoders)
                msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                                  msg_llr,
                                                                  msg_pm)

        return msg_uhat, msg_llr, msg_pm

    def _decode_tf(self, llr_ch):
        """Main decoding function in TF.

        Initializes memory and calls recursive decoding function.
        """

        batch_size = tf.shape(llr_ch)[0]

        # allocate memory for all 2*list_size decoders
        msg_uhat = tf.zeros([batch_size,
                             2*self._list_size,
                             self._n_stages+1,
                             self._n])
        msg_llr = tf.zeros([batch_size,
                            2*self._list_size,
                            self._n_stages,
                            self._n])
        # init all 2*l decoders with same llr_ch
        llr_ch = tf.reshape(llr_ch, [-1, 1, 1, self._n])
        llr_ch = tf.tile(llr_ch,[1, 2*self._list_size, 1, 1])

        # init last stage with llr_ch
        msg_llr = tf.concat([msg_llr, llr_ch], 2)

        # init all remaining L-1 decoders with high penalty
        # (only one decoder per half starts with path metric 0)
        pm0 = tf.zeros([batch_size, 1])
        pm1 = self._llr_max * tf.ones([batch_size, self._list_size-1])
        msg_pm = tf.concat([pm0, pm1, pm0, pm1], 1)

        # and call recursive graph function
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(self._cw_ind,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)

        # and sort output
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        return [msg_uhat, msg_pm]

    ####################################
    # Helper functions for Numpy decoder
    ####################################

    def _update_rate0_code_np(self, cw_ind):
        """Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind`` in Numpy.

        See Eq. (26) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM for each batch sample
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, stage_ind, cw_ind],
                                    ind,
                                    axis=1)

        # all-zero u_hat assumed; penalty is log(1+exp(-llr)) per position
        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
        pm_val = np.log(1 + np.exp(-llr_clip))
        self.msg_pm += np.sum(pm_val, axis=-1)

    def _update_rep_code_np(self, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``cw_ind`` in Numpy.

        See Eq. (31) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))
        bs = self._dec_pointer.shape[0]

        # update PM
        # gather the LLRs of this sub-code via the decoder pointers
        llr = np.zeros([bs, 2*self._list_size, n])
        for i in range(bs):
            llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
            llr[i, :, :] = llr_i[:, cw_ind]

        # upper branch has negative llr values (bit is 1)
        llr[:, self._list_size:, :] = - llr[:, self._list_size:, :]
        llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
        pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
        self.msg_pm += pm_val

        # upper half of the decoders decided for the all-one codeword
        for i in range(bs):
            ind_dec = self._dec_pointer[i, self._list_size:]
            for j in cw_ind:
                self.msg_uhat[i, ind_dec, stage_ind, j] = 1

        # branch last bit and update pm at pos cw_ind[-1]
        self._update_single_bit_np([cw_ind[-1]])
        self._sort_decoders_np()
        self._duplicate_paths_np()

    def _update_single_bit_np(self, ind_u):
        """Update single bit at position ``ind_u`` of all decoders in Numpy."""

        if self._frozen_ind[ind_u]==0: # position is non-frozen
            # upper half of the (pointer-indexed) decoders decides for "1"
            ind_dec = np.expand_dims(self._dec_pointer[:, self._list_size:],
                                     axis=-1)
            uhat_slice = self.msg_uhat[:, :, 0, ind_u]
            np.put_along_axis(uhat_slice, ind_dec, 1., axis=1)
            self.msg_uhat[:, :, 0, ind_u] = uhat_slice


    def _update_pm_np(self, ind_u):
        """ Update path metric of all decoders at bit position ``ind_u`` in
        Numpy.

        We apply Eq. (10) from [Stimming_LLR]_.
        """
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u], ind, axis=1)
        u_hat = np.squeeze(u_hat, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u], ind, axis=1)
        llr_in = np.squeeze(llr_in, axis=-1)

        # penalty only materializes if the decision contradicts the LLR sign
        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
        self.msg_pm += np.log(1 + np.exp(-np.multiply((1-2*u_hat), llr_clip)))

    def _sort_decoders_np(self):
        """Sort decoders according to their path metric."""

        # only the pointers are sorted; messages stay in place
        ind = np.argsort(self.msg_pm, axis=-1)
        self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
        self._dec_pointer = np.take_along_axis(self._dec_pointer, ind, axis=1)

    def _cn_op_np(self, x, y):
        """Check node update (boxplus) for LLRs in Numpy.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
        y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)

        # avoid division for numerical stability
        llr_out = np.log(1 + np.exp(x_in + y_in))
        llr_out -= np.log(np.exp(x_in) + np.exp(y_in))

        return llr_out

    def _vn_op_np(self, x, y, u_hat):
        """Variable node update (boxplus) for LLRs in Numpy."""
        return np.multiply((1-2*u_hat), x) + y

    def _duplicate_paths_np(self):
        """Copy first ``list_size``/2 paths into lower part in Numpy.

        Decoder indices are encoded in ``self._dec_pointer``.
        """
        ind_low = self._dec_pointer[:, :self._list_size]
        ind_up = self._dec_pointer[:, self._list_size:]

        for i in range(ind_up.shape[0]):
            self.msg_uhat[i, ind_up[i,:], :, :] = self.msg_uhat[i,
                                                                ind_low[i,:],
                                                                :, :]
            self.msg_llr[i, ind_up[i,:],:,:] = self.msg_llr[i, ind_low[i,:],:,:]

        # pm must be sorted directly (not accessed via pointer)
        self.msg_pm[:, self._list_size:] = self.msg_pm[:, :self._list_size]

    def _polar_decode_scl_np(self, cw_ind):
        """Recursive decoding function in Numpy.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and `right` update paths until
        reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n>1:
            # prune tree if rate-0 subcode or rep-code is detected
            if self._use_fast_scl:
                if np.sum(self._frozen_ind[cw_ind])==n:
                    # rate0 code detected
                    self._update_rate0_code_np(cw_ind)
                    return
                if (self._frozen_ind[cw_ind[-1]]==0 and
                    np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
                    # rep code detected
                    self._update_rep_code_np(cw_ind)
                    return
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_left] = self._cn_op_np(
                                                                    llr_left,
                                                                    llr_right)

            # call left branch decoder
            self._polar_decode_scl_np(cw_ind_left)

            # ----- right branch -----
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_right] = self._vn_op_np(
                                                                llr_left,
                                                                llr_right,
                                                                u_hat_left_up)

            # call right branch decoder
            self._polar_decode_scl_np(cw_ind_right)

            # combine u_hat
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            u_hat_right_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_right]

            # u_hat_left_up XOR u_hat_right_up
            u_hat_left = (u_hat_left_up != u_hat_right_up) + 0

            u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)

            # provide u_hat for next higher stage
            self.msg_uhat[:, :, stage_ind, cw_ind] = u_hat

        else: # if leaf is reached perform basic decoding op (=decision)

            self._update_single_bit_np(cw_ind)

            # update PM
            self._update_pm_np(cw_ind)

            # position is non-frozen
            if self._frozen_ind[cw_ind]==0:
                # sort list
                self._sort_decoders_np()
                # duplicate the best list_size decoders
                self._duplicate_paths_np()
        return

    def _decode_np_batch(self, llr_ch):
        """Decode batch of ``llr_ch`` with Numpy decoder."""

        bs = llr_ch.shape[0]

        # allocate memory for all 2*list_size decoders
        self.msg_uhat = np.zeros([bs,
                                  2*self._list_size,
                                  self._n_stages+1,
                                  self._n])
        self.msg_llr = np.zeros([bs,
                                 2*self._list_size,
                                 self._n_stages+1,
                                 self._n])
        self.msg_pm = np.zeros([bs,
                                2*self._list_size])

        # L-1 decoders start with high penalty
        self.msg_pm[:,1:self._list_size] = self._llr_max
        # same for the second half of the L-1 decoders
        self.msg_pm[:,self._list_size+1:] = self._llr_max

        # use pointers to avoid in-memory sorting
        self._dec_pointer = np.arange(2*self._list_size)
        self._dec_pointer = np.tile(np.expand_dims(self._dec_pointer, axis=0),
                                    [bs,1])

        # init llr_ch (broadcast via list dimension)
        self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch, axis=1)

        # call recursive graph function
        self._polar_decode_scl_np(self._cw_ind)

        # select most likely candidate
        self._sort_decoders_np()

        # remove pointers
        for ind in range(bs):
            self.msg_uhat[ind, :, :, :] = self.msg_uhat[ind,
                                                        self._dec_pointer[ind],
                                                        :, :]
        return self.msg_uhat, self.msg_pm

    def _decode_np_hybrid(self, llr_ch, u_hat_sc, crc_valid):
        """Hybrid SCL decoding stage that decodes iff CRC from previous SC
        decoding attempt failed.

        This option avoids the usage of the high-complexity SCL decoder in cases
        where SC would be sufficient. For further details we refer to
        [Cammerer_Hybrid_SCL]_ (we use SC instead of the proposed BP stage).

        Remark: This decoder does not exactly implement SCL as the CRC
        can be false positive after the SC stage. However, in these cases
        SCL+CRC may also yield the wrong results.

        Remark 2: Due to the excessive control flow (if/else) and the
        varying batch-sizes, this function is only available as Numpy
        decoder (i.e., runs on the CPU).
        """

        bs = llr_ch.shape[0]
        crc_valid = np.squeeze(crc_valid, axis=-1)
        # index of codewords that need SCL decoding
        ind_invalid = np.arange(bs)[np.invert(crc_valid)]

        # init SCL decoder for bs_hyb samples requiring SCL dec.
        llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
        msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)

        # merge results with previously decoded SC results
        msg_uhat = np.zeros([bs, 2*self._list_size, 1, self._n])
        msg_pm = np.ones([bs, 2*self._list_size]) * self._llr_max * self.k
        msg_pm[:, 0] = 0

        # copy SC data
        msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc

        # interleave SCL results back into the positions that failed the CRC
        ind_hyb = 0
        for ind in range(bs):
            if not crc_valid[ind]:
                #copy data from SCL
                msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
                msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
                ind_hyb += 1

        return msg_uhat, msg_pm

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        assert (input_shape[-1]==self._n), "Invalid input shape."
        assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'

    def call(self, inputs):
        """Successive cancellation list (SCL) decoding function.

        This function performs successive cancellation list decoding
        and returns the estimated information bits.

        An outer CRC can be applied optionally by setting ``crc_degree``.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float32`: Tensor of shape `[...,k]` containing
                hard-decided estimations of all ``k`` information bits.

        Raises:
            ValueError: If ``inputs`` is not of shape `[..., n]`
                or `dtype` is not `tf.float32`.

            InvalidArgumentError: When rank(``inputs``)<2.

        Note:
            This function recursively unrolls the SCL decoding tree, thus,
            for larger values of ``n`` building the decoding graph can become
            time consuming. Please consider the ``cpu_only`` option instead.
        """

        tf.debugging.assert_type(inputs, self._output_dtype,
                                 "Invalid input dtype.")
        # internal calculations still in tf.float32
        inputs = tf.cast(inputs, tf.float32)

        # last dim must be of length n
        tf.debugging.assert_equal(tf.shape(inputs)[-1],
                                  self._n,
                                  "Last input dimension must be of length n.")

        # Reshape inputs to [-1, n]
        tf.debugging.assert_greater(tf.rank(inputs), 1)
        input_shape = inputs.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(inputs, new_shape)

        llr_ch = -1. * llr_ch # logits are converted into "true" llrs

        # if activated use Numpy decoder
        if self._use_hybrid_sc:
            # use SC decoder to decode first
            u_hat = self._decoder_sc(-llr_ch)
            _, crc_valid = self._crc_decoder(u_hat)
            msg_uhat, msg_pm = tf.py_function(func=self._decode_np_hybrid,
                                              inp=[llr_ch, u_hat, crc_valid],
                                              Tout=[tf.float32, tf.float32])
            # note: return shape is only 1 in 3. dim (to avoid copy overhead)
            msg_uhat = tf.reshape(msg_uhat, [-1, 2*self._list_size, 1, self._n])
            msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
        else:
            if self._cpu_only:
                msg_uhat, msg_pm = tf.py_function(func=self._decode_np_batch,
                                                  inp=[llr_ch],
                                                  Tout=[tf.float32, tf.float32])
                # restore shape information
                msg_uhat = tf.reshape(msg_uhat,
                            [-1, 2*self._list_size, self._n_stages+1, self._n])
                msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
            else:
                msg_uhat, msg_pm = self._decode_tf(llr_ch)

        # check CRC (and remove CRC parity bits)
        if self._use_crc:
            u_hat_list = tf.gather(msg_uhat[:, :, 0, :],
                                   self._info_pos,
                                   axis=-1)
            # undo input bit interleaving
            # remark: the output is not interleaved for compatibility with SC
            if self._iil:
                u_hat_list_crc = tf.gather(u_hat_list,
                                           self._ind_iil_inv,
                                           axis=-1)
            else: # no interleaving applied
                u_hat_list_crc = u_hat_list

            _, crc_valid = self._crc_decoder(u_hat_list_crc)
            # add penalty to pm if CRC fails
            pm_penalty = ((1. - tf.cast(crc_valid, tf.float32))
                          * self._llr_max * self.k)
            msg_pm += tf.squeeze(pm_penalty, axis=2)

        # select most likely candidate
        cand_ind = tf.argmin(msg_pm, axis=-1)
        c_hat = tf.gather(msg_uhat[:, :, 0, :], cand_ind, axis=1, batch_dims=1)
        u_hat = tf.gather(c_hat, self._info_pos, axis=-1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1 # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        if self._return_crc_status:
            # reconstruct CRC status
            crc_status = tf.gather(crc_valid, cand_ind, axis=1, batch_dims=1)
            # reconstruct shape
            output_shape.pop() # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)

            crc_status = tf.cast(crc_status, self._output_dtype)
            # return info bits and CRC status
            return tf.cast(u_hat_reshape, self._output_dtype), crc_status
        else: # return only info bits
            return tf.cast(u_hat_reshape, self._output_dtype)


class PolarBPDecoder(Layer):
    # pylint: disable=line-too-long
    """PolarBPDecoder(frozen_pos, n, num_iter=20, hard_out=True, output_dtype=tf.float32, **kwargs)

    Belief propagation (BP) decoder for Polar codes [Arikan_Polar]_ and
    Polar-like codes based on [Arikan_BP]_ and [Forney_Graphs]_.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Remark: The PolarBPDecoder does currently not support XLA.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    num_iter: int
        Defining the number of decoder iterations (no early stopping used
        at the moment).

    hard_out: bool
        Defaults to True. If True, the decoder provides hard-decided
        information bits instead of soft-values.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer
        (internal precision remains tf.float32).

    Input
    -----
    inputs: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    Output
    ------
    : [...,k], tf.float32
        2+D tensor containing bit-wise soft-estimates
        (or hard-decided bit-values) of all ``k`` information bits.

    Raises
    ------
    AssertionError
        If ``n`` is not `int`.

    AssertionError
        If ``n`` is not a power of 2.

    AssertionError
        If the number of elements in ``frozen_pos`` is greater than ``n``.

    AssertionError
        If ``frozen_pos`` does not consist of `int`.

    AssertionError
        If ``hard_out`` is not `bool`.

    ValueError
        If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.

    AssertionError
        If ``num_iter`` is not `int`.

    AssertionError
        If ``num_iter`` is not a positive value.

    Note
    ----
    This decoder is fully differentiable and, thus, well-suited for
    gradient descent-based learning tasks such as `learned code design`
    [Ebada_Design]_.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear code as the
    `all-zero` codeword is not necessarily part of the code any more.

    """

    def __init__(self,
                 frozen_pos,
                 n,
                 num_iter=20,
                 hard_out=True,
                 output_dtype=tf.float32,
                 **kwargs):

        if output_dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'output_dtype must be {tf.float16, tf.float32, tf.float64}.')

        if output_dtype is not tf.float32:
            print('Note: decoder uses tf.float32 for internal calculations.')

        super().__init__(dtype=output_dtype, **kwargs)
        self._output_dtype = output_dtype

        # assert error if r>1 or k, n are negative
        assert isinstance(n, numbers.Number), "n must be a number."
        n = int(n) # n can be float (e.g. as result of n=k*r)
        assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
        assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
                                   "be greater than n."
        assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."

        assert isinstance(hard_out, bool), "hard_out must be boolean."

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        assert self._k==len(self._info_pos), "Internal error: invalid " \
                                             "info_pos generated."

        assert isinstance(num_iter, int), "num_iter must be integer."
        assert num_iter>0, "num_iter must be a positive value."
1677 self._num_iter = tf.constant(num_iter, dtype=tf.int32) 1678 1679 self._llr_max = 19.3 # internal max LLR value 1680 self._hard_out = hard_out 1681 1682 # depth of decoding graph 1683 self._n_stages = int(np.log2(self._n)) 1684 1685 ######################################### 1686 # Public methods and properties 1687 ######################################### 1688 1689 @property 1690 def n(self): 1691 """Codeword length.""" 1692 return self._n 1693 1694 @property 1695 def k(self): 1696 """Number of information bits.""" 1697 return self._k 1698 1699 @property 1700 def frozen_pos(self): 1701 """Frozen positions for Polar decoding.""" 1702 return self._frozen_pos 1703 1704 @property 1705 def info_pos(self): 1706 """Information bit positions for Polar encoding.""" 1707 return self._info_pos 1708 1709 @property 1710 def llr_max(self): 1711 """Maximum LLR value for internal calculations.""" 1712 return self._llr_max 1713 1714 @property 1715 def num_iter(self): 1716 """Number of decoding iterations.""" 1717 return self._num_iter 1718 1719 @property 1720 def hard_out(self): 1721 """Indicates if decoder hard-decides outputs.""" 1722 return self._hard_out 1723 1724 @property 1725 def output_dtype(self): 1726 """Output dtype of decoder.""" 1727 return self._output_dtype 1728 1729 @num_iter.setter 1730 def num_iter(self, num_iter): 1731 "Number of decoding iterations." 1732 assert isinstance(num_iter, int), 'num_iter must be int.' 1733 assert num_iter>=0, 'num_iter cannot be negative.' 1734 self._num_iter = tf.constant(num_iter, dtype=tf.int32) 1735 1736 ######################### 1737 # Utility methods 1738 ######################### 1739 1740 def _boxplus_tf(self, x, y): 1741 """Check-node update (boxplus) for LLR inputs. 1742 1743 Operations are performed element-wise. 
1744 """ 1745 x_in = tf.clip_by_value(x, 1746 clip_value_min=-self._llr_max, 1747 clip_value_max=self._llr_max) 1748 y_in = tf.clip_by_value(y, 1749 clip_value_min=-self._llr_max, 1750 clip_value_max=self._llr_max) 1751 1752 # avoid division for numerical stability 1753 llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in)) 1754 llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in)) 1755 1756 return llr_out 1757 1758 def _decode_bp(self, llr_ch, num_iter): 1759 """Iterative BP decoding function with LLR-values. 1760 1761 Args: 1762 llr_ch (tf.float32): Tensor of shape `[batch_size, n]` containing 1763 the channel logits/llr values where `batch_size` denotes the 1764 batch-size. 1765 1766 num_iter (int): Defining the number of decoder iteration 1767 (no early stopping used at the moment). 1768 Returns: 1769 `tf.float32`: Tensor of shape `[batch_size, k]` containing 1770 bit-wise soft-estimates (or hard-decided bit-values) of all 1771 information bits. 1772 """ 1773 1774 bs = tf.shape(llr_ch)[0] 1775 1776 # store intermediate Tensors in TensorArray 1777 msg_l = tf.TensorArray(tf.float32, 1778 size=num_iter*(self._n_stages+1), 1779 dynamic_size=False, 1780 clear_after_read=False) 1781 1782 msg_r = tf.TensorArray(tf.float32, 1783 size=num_iter*(self._n_stages+1), 1784 dynamic_size=False, 1785 clear_after_read=False) 1786 1787 # init frozen positions with infinity 1788 msg_r_in = np.zeros([1, self._n]) 1789 msg_r_in[:, self._frozen_pos] = self._llr_max 1790 # copy for all batch-samples 1791 msg_r_in = tf.tile(tf.constant(msg_r_in, tf.float32), [bs, 1]) 1792 1793 # perform decoding iterations 1794 for ind_it in tf.range(self._num_iter): 1795 # update left-to-right messages 1796 for ind_s in range(self._n_stages): 1797 # calc indices 1798 ind_range = np.arange(int(self._n/2)) 1799 ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s) 1800 ind_2 = ind_1 + 2**ind_s 1801 # simplify gather with concatenated outputs 1802 ind_inv = np.argsort(np.concatenate([ind_1, ind_2], 
axis=0)) 1803 1804 # load incoming l messages 1805 if ind_s==self._n_stages-1: 1806 l1_in = tf.gather(llr_ch, ind_1, axis=1) 1807 l2_in = tf.gather(llr_ch, ind_2, axis=1) 1808 elif ind_it==0: 1809 l1_in = tf.zeros([bs, int(self._n/2)]) 1810 l2_in = tf.zeros([bs, int(self._n/2)]) 1811 else: 1812 l_in = msg_l.read((ind_s+1) + (ind_it-1)*(self._n_stages+1)) 1813 l1_in = tf.gather(l_in, ind_1, axis=1) 1814 l2_in = tf.gather(l_in, ind_2, axis=1) 1815 1816 # load incoming r messages 1817 if ind_s==0: 1818 r1_in = tf.gather(msg_r_in, ind_1, axis=1) 1819 r2_in = tf.gather(msg_r_in, ind_2, axis=1) 1820 else: 1821 r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1)) 1822 r1_in = tf.gather(r_in, ind_1, axis=1) 1823 r2_in = tf.gather(r_in, ind_2, axis=1) 1824 1825 r1_out = self._boxplus_tf(r1_in, l2_in + r2_in) 1826 r2_out = self._boxplus_tf(r1_in, l1_in) + r2_in 1827 1828 # and re-concatenate output 1829 r_out = tf.concat([r1_out, r2_out], 1) 1830 r_out = tf.gather(r_out, ind_inv, axis=1) 1831 msg_r = msg_r.write((ind_s+1) 1832 + ind_it*(self._n_stages+1), r_out) 1833 1834 # update right-to-left messages 1835 for ind_s in range(self._n_stages-1, -1, -1): 1836 ind_range = np.arange(int(self._n/2)) 1837 ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s) 1838 ind_2 = ind_1 + 2**ind_s 1839 ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0)) 1840 1841 # load messages 1842 if ind_s==self._n_stages-1: 1843 l1_in = tf.gather(llr_ch, ind_1, axis=1) 1844 l2_in = tf.gather(llr_ch, ind_2, axis=1) 1845 else: 1846 l_in = msg_l.read((ind_s+1)+ind_it*(self._n_stages+1)) 1847 l1_in = tf.gather(l_in, ind_1, axis=1) 1848 l2_in = tf.gather(l_in, ind_2, axis=1) 1849 1850 if ind_s==0: 1851 r1_in = tf.gather(msg_r_in, ind_1, axis=1) 1852 r2_in = tf.gather(msg_r_in, ind_2, axis=1) 1853 else: 1854 r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1)) 1855 r1_in = tf.gather(r_in, ind_1, axis=1) 1856 r2_in = tf.gather(r_in, ind_2, axis=1) 1857 1858 # node update functions 1859 l1_out = 
self._boxplus_tf(l1_in, l2_in + r2_in) 1860 l2_out = self._boxplus_tf(r1_in, l1_in) + l2_in 1861 1862 l_out = tf.concat([l1_out, l2_out], 1) 1863 l_out = tf.gather(l_out, ind_inv, axis=1) 1864 msg_l = msg_l.write(ind_s + ind_it*(self._n_stages+1), l_out) 1865 1866 # recover u_hat 1867 u_hat = tf.gather(msg_l.read((num_iter-1)*(self._n_stages+1)), 1868 self._info_pos, 1869 axis=1) 1870 # if active, hard-decide output bits 1871 if self._hard_out: 1872 u_hat = tf.where(u_hat>0, 0., 1.) 1873 else: # re-transform soft output to logits (instead of llrs) 1874 u_hat = -1. * u_hat 1875 return u_hat 1876 1877 ######################### 1878 # Keras layer functions 1879 ######################### 1880 1881 def build(self, input_shape): 1882 """Build and check if shape of input is invalid.""" 1883 assert (input_shape[-1]==self._n), "Invalid input shape" 1884 assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.' 1885 1886 def call(self, inputs): 1887 """Iterative BP decoding function. 1888 1889 This function performs `num_iter` belief propagation decoding iterations 1890 and returns the estimated information bits. 1891 1892 Args: 1893 inputs (tf.float32): Tensor of shape `[...,n]` containing the 1894 channel logits/llr values. 1895 1896 Returns: 1897 `tf.float32`: Tensor of shape `[...,k]` containing 1898 bit-wise soft-estimates (or hard-decided bit-values) of all 1899 ``k`` information bits. 1900 1901 Raises: 1902 ValueError: If ``inputs`` is not of shape `[..., n]` 1903 or `dtype` is not `output_dtype`. 1904 1905 InvalidArgumentError: When rank(``inputs``)<2. 1906 1907 Note: 1908 This function recursively unrolls the BP decoding graph, thus, 1909 for larger values of ``n`` or more iterations, building the 1910 decoding graph can become time and memory consuming. 
1911 """ 1912 1913 tf.debugging.assert_type(inputs, self._output_dtype, 1914 "Invalid input dtype.") 1915 # internal calculations still in tf.float32 1916 inputs = tf.cast(inputs, tf.float32) 1917 1918 # Reshape inputs to [-1, n] 1919 input_shape = inputs.shape 1920 new_shape = [-1, self._n] 1921 llr_ch = tf.reshape(inputs, new_shape) 1922 1923 llr_ch = -1. * llr_ch # logits are converted into "true" llrs 1924 1925 # and decode 1926 u_hat = self._decode_bp(llr_ch, self._num_iter) 1927 1928 # and reconstruct input shape 1929 output_shape = input_shape.as_list() 1930 output_shape[-1] = self.k 1931 output_shape[0] = -1 # first dim can be dynamic (None) 1932 u_hat_reshape = tf.reshape(u_hat, output_shape) 1933 return tf.cast(u_hat_reshape, self._output_dtype) 1934 1935 1936 class Polar5GDecoder(Layer): 1937 # pylint: disable=line-too-long 1938 """Polar5GDecoder(enc_polar, dec_type="SC", list_size=8, num_iter=20,return_crc_status=False, output_dtype=tf.float32, **kwargs) 1939 1940 Wrapper for 5G compliant decoding including rate-recovery and CRC removal. 1941 1942 The class inherits from the Keras layer class and can be used as layer in a 1943 Keras model. 1944 1945 Parameters 1946 ---------- 1947 enc_polar: Polar5GEncoder 1948 Instance of the :class:`~sionna.fec.polar.encoding.Polar5GEncoder` 1949 used for encoding including rate-matching. 1950 1951 dec_type: str 1952 Defaults to `"SC"`. Defining the decoder to be used. 1953 Must be one of the following `{"SC", "SCL", "hybSCL", "BP"}`. 1954 1955 list_size: int 1956 Defaults to 8. Defining the list size `iff` list-decoding is used. 1957 Only required for ``dec_types`` `{"SCL", "hybSCL"}`. 1958 1959 num_iter: int 1960 Defaults to 20. Defining the number of BP iterations. Only required 1961 for ``dec_type`` `"BP"`. 1962 1963 return_crc_status: bool 1964 Defaults to False. If True, the decoder additionally returns the 1965 CRC status indicating if a codeword was (most likely) correctly 1966 recovered. 

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer
        (internal precision remains tf.float32).

    Input
    -----
    inputs: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    Output
    ------

    b_hat : [...,k], tf.float32
        2+D tensor containing hard-decided estimations of all `k`
        information bits.

    crc_status : [...], output_dtype
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is True.
        Note that false positives are possible.

    Raises
    ------
    AssertionError
        If ``enc_polar`` is not `Polar5GEncoder`.

    ValueError
        If ``dec_type`` is not `{"SC", "SCL", "hybSCL", "BP"}`.

    AssertionError
        If ``dec_type`` is not `str`.

    ValueError
        If ``inputs`` is not of shape `[..., n]` or `dtype` is not
        the same as ``output_dtype``.

    InvalidArgumentError
        When rank(``inputs``)<2.

    Note
    ----
    This layer supports the uplink and downlink Polar rate-matching scheme
    without `codeword segmentation`.

    Although the decoding `list size` is not provided by 3GPP
    [3GPPTS38212]_, the consortium has agreed on a `list size` of 8 for the
    5G decoding reference curves [Bioglio_Design]_.

    All list-decoders apply `CRC-aided` decoding, however, the non-list
    decoders (`"SC"` and `"BP"`) cannot materialize the CRC leading to an
    effective rate-loss.

    """

    def __init__(self,
                 enc_polar,
                 dec_type="SC",
                 list_size=8,
                 num_iter=20,
                 return_crc_status=False,
                 output_dtype=tf.float32,
                 **kwargs):

        if output_dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'output_dtype must be {tf.float16, tf.float32, tf.float64}.')

        if output_dtype is not tf.float32:
            print('Note: decoder uses tf.float32 for internal calculations.')
        self._output_dtype = output_dtype

        super().__init__(dtype=output_dtype, **kwargs)

        assert isinstance(enc_polar, Polar5GEncoder), \
            "enc_polar must be Polar5GEncoder."
        assert isinstance(dec_type, str), "dec_type must be str."
        # list_size and num_iter are not checked here (done during decoder init)

        # Store internal attributes
        self._n_target = enc_polar.n_target
        self._k_target = enc_polar.k_target
        self._n_polar = enc_polar.n_polar
        self._k_polar = enc_polar.k_polar
        # number of CRC bits appended by the encoder (removed after decoding)
        self._k_crc = enc_polar.enc_crc.crc_length
        # uplink applies channel (bit) interleaving, downlink the input
        # bit interleaver (see call())
        self._bil = enc_polar._channel_type == "uplink"
        self._iil = enc_polar._channel_type == "downlink"
        self._llr_max = 100 # Internal max LLR value (for punctured positions)
        self._enc_polar = enc_polar
        self._dec_type = dec_type

        # Initialize the de-interleaver patterns
        self._init_interleavers()

        # Initialize decoder
        if dec_type=="SC":
            print("Warning: 5G Polar codes use an integrated CRC that " \
                  "cannot be materialized with SC decoding and, thus, " \
                  "causes a degraded performance. Please consider SCL " \
                  "decoding instead.")
            self._polar_dec = PolarSCDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar)
        elif dec_type=="SCL":
            self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
                                self._n_polar,
                                crc_degree=self._enc_polar.enc_crc.crc_degree,
                                list_size=list_size,
                                ind_iil_inv = self.ind_iil_inv)
        elif dec_type=="hybSCL":
            self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
                                self._n_polar,
                                crc_degree=self._enc_polar.enc_crc.crc_degree,
                                list_size=list_size,
                                use_hybrid_sc=True,
                                ind_iil_inv = self.ind_iil_inv)
        elif dec_type=="BP":
            print("Warning: 5G Polar codes use an integrated CRC that " \
                  "cannot be materialized with BP decoding and, thus, " \
                  "causes a degraded performance. Please consider SCL " \
                  " decoding instead.")
            assert isinstance(num_iter, int), "num_iter must be int."
            assert num_iter > 0, "num_iter must be positive."
            self._num_iter = num_iter
            self._polar_dec = PolarBPDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar,
                                             num_iter=num_iter,
                                             hard_out=True)
        else:
            raise ValueError("Unknown value for dec_type.")

        assert isinstance(return_crc_status, bool), \
            "return_crc_status must be bool."
2099 2100 self._return_crc_status = return_crc_status 2101 if self._return_crc_status: # init crc decoder 2102 if dec_type in ("SCL", "hybSCL"): 2103 # re-use CRC decoder from list decoder 2104 self._dec_crc = self._polar_dec._crc_decoder 2105 else: # init new CRC decoder for BP and SC 2106 self._dec_crc = CRCDecoder(self._enc_polar._enc_crc) 2107 2108 ######################################### 2109 # Public methods and properties 2110 ######################################### 2111 2112 @property 2113 def k_target(self): 2114 """Number of information bits including rate-matching.""" 2115 return self._k_target 2116 2117 @property 2118 def n_target(self): 2119 """Codeword length including rate-matching.""" 2120 return self._n_target 2121 2122 @property 2123 def k_polar(self): 2124 """Number of information bits of mother Polar code.""" 2125 return self._k_polar 2126 2127 @property 2128 def n_polar(self): 2129 """Codeword length of mother Polar code.""" 2130 return self._n_polar 2131 2132 @property 2133 def frozen_pos(self): 2134 """Frozen positions for Polar decoding.""" 2135 return self._frozen_pos 2136 2137 @property 2138 def info_pos(self): 2139 """Information bit positions for Polar encoding.""" 2140 return self._info_pos 2141 2142 @property 2143 def llr_max(self): 2144 """Maximum LLR value for internal calculations.""" 2145 return self._llr_max 2146 2147 @property 2148 def dec_type(self): 2149 """Decoder type used for decoding as str.""" 2150 return self._dec_type 2151 2152 @property 2153 def polar_dec(self): 2154 """Decoder instance used for decoding.""" 2155 return self._polar_dec 2156 2157 @property 2158 def output_dtype(self): 2159 """Output dtype of decoder.""" 2160 return self._output_dtype 2161 2162 ######################### 2163 # Utility methods 2164 ######################### 2165 2166 def _init_interleavers(self): 2167 """Initialize inverse interleaver patterns for rate-recovery.""" 2168 2169 # Channel interleaver 2170 ind_ch_int = 
self._enc_polar.channel_interleaver( 2171 np.arange(self._n_target)) 2172 self.ind_ch_int_inv = np.argsort(ind_ch_int) # Find inverse perm 2173 2174 # Sub-block interleaver 2175 ind_sub_int = self._enc_polar.subblock_interleaving( 2176 np.arange(self._n_polar)) 2177 self.ind_sub_int_inv = np.argsort(ind_sub_int) # Find inverse perm 2178 2179 # input bit interleaver 2180 if self._iil: 2181 self.ind_iil_inv = np.argsort(self._enc_polar.input_interleaver( 2182 np.arange(self._k_polar))) 2183 else: 2184 self.ind_iil_inv = None 2185 ######################### 2186 # Keras layer functions 2187 ######################### 2188 2189 def build(self, input_shape): 2190 """Build and check if shape of input is invalid.""" 2191 assert (input_shape[-1]==self._n_target), "Invalid input shape." 2192 assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.' 2193 2194 def call(self, inputs): 2195 """Polar decoding and rate-recovery for uplink 5G Polar codes. 2196 2197 Args: 2198 inputs (tf.float32): Tensor of shape `[...,n]` containing the 2199 channel logits/llr values. 2200 2201 Returns: 2202 `tf.float32`: Tensor of shape `[...,k]` containing 2203 hard-decided estimates of all ``k`` information bits. 2204 2205 Raises: 2206 ValueError: If ``inputs`` is not of shape `[..., n]` 2207 or `dtype` is not `output_dtype`. 2208 2209 InvalidArgumentError: When rank(``inputs``)<2. 2210 """ 2211 2212 tf.debugging.assert_type(inputs, self._output_dtype, 2213 "Invalid input dtype.") 2214 # internal calculations still in tf.float32 2215 inputs = tf.cast(inputs, tf.float32) 2216 2217 # Reshape inputs to [-1, n] 2218 tf.debugging.assert_greater(tf.rank(inputs), 1) 2219 input_shape = inputs.shape 2220 new_shape = [-1, self._n_target] 2221 llr_ch = tf.reshape(inputs, new_shape) 2222 2223 # Note: logits are not inverted here; this is done in the decoder itself 2224 2225 # 1.) 
Undo channel interleaving 2226 if self._bil: 2227 llr_deint = tf.gather(llr_ch, self.ind_ch_int_inv, axis=1) 2228 else: 2229 llr_deint = llr_ch 2230 2231 # 2.) Remove puncturing, shortening, repetition (see Sec. 5.4.1.2) 2232 # a) Puncturing: set LLRs to 0 2233 # b) Shortening: set LLRs to infinity 2234 # c) Repetition: combine LLRs 2235 if self._n_target >= self._n_polar: 2236 # Repetition coding 2237 # Add the last n_rep positions to the first llr positions 2238 n_rep = self._n_target - self._n_polar 2239 llr_1 = llr_deint[:,:n_rep] 2240 llr_2 = llr_deint[:,n_rep:self._n_polar] 2241 llr_3 = llr_deint[:,self._n_polar:] 2242 llr_dematched = tf.concat([llr_1+llr_3, llr_2], 1) 2243 else: 2244 if self._k_polar/self._n_target <= 7/16: 2245 # Puncturing 2246 # Append n_polar - n_target "zero" llrs to first positions 2247 llr_zero = tf.zeros([tf.shape(llr_deint)[0], 2248 self._n_polar-self._n_target]) 2249 llr_dematched = tf.concat([llr_zero, llr_deint], 1) 2250 else: 2251 # Shortening 2252 # Append n_polar - n_target "-infinity" llrs to last positions 2253 # Remark: we still operate with logits here, thus the neg. sign 2254 llr_infty = -self._llr_max * tf.ones([tf.shape(llr_deint)[0], 2255 self._n_polar-self._n_target]) 2256 llr_dematched = tf.concat([llr_deint, llr_infty], 1) 2257 2258 # 3.) Remove subblock interleaving 2259 llr_dec = tf.gather(llr_dematched, self.ind_sub_int_inv, axis=1) 2260 2261 # 4.) Run main decoder 2262 u_hat_crc = self._polar_dec(llr_dec) 2263 2264 # 5.) Shortening should be implicitly recovered by decoder 2265 2266 # 6.) Remove input bit interleaving for downlink channels only 2267 if self._iil: 2268 u_hat_crc = tf.gather(u_hat_crc, self.ind_iil_inv, axis=1) 2269 2270 # 7.) 
Evaluate or remove CRC (and PC) 2271 if self._return_crc_status: 2272 # for compatibility with SC/BP, a dedicated CRC decoder is 2273 # used here (instead of accessing the interal SCL) 2274 u_hat, crc_status = self._dec_crc(u_hat_crc) 2275 else: # just remove CRC bits 2276 u_hat = u_hat_crc[:,:-self._k_crc] 2277 2278 # And reconstruct input shape 2279 output_shape = input_shape.as_list() 2280 output_shape[-1] = self._k_target 2281 output_shape[0] = -1 # First dim can be dynamic (None) 2282 u_hat_reshape = tf.reshape(u_hat, output_shape) 2283 # and cast to output dtype 2284 u_hat_reshape = tf.cast(u_hat_reshape, dtype=self._output_dtype) 2285 2286 if self._return_crc_status: 2287 # reconstruct CRC shape 2288 output_shape.pop() # remove last dimension 2289 crc_status = tf.reshape(crc_status, output_shape) 2290 crc_status = tf.cast(crc_status, dtype=self._output_dtype) 2291 return u_hat_reshape, crc_status 2292 2293 else: 2294 return u_hat_reshape