decoding.py (55273B)
1 # 2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 # SPDX-License-Identifier: Apache-2.0 4 # 5 """Layers for channel decoding and utility functions.""" 6 7 import tensorflow as tf 8 import numpy as np 9 import scipy as sp # for sparse H matrix computations 10 from tensorflow.keras.layers import Layer 11 from sionna.fec.ldpc.encoding import LDPC5GEncoder 12 from sionna.fec.utils import llr2mi 13 import matplotlib.pyplot as plt 14 15 class LDPCBPDecoder(Layer): 16 # pylint: disable=line-too-long 17 r"""LDPCBPDecoder(pcm, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, num_iter=20, stateful=False,output_dtype=tf.float32, **kwargs) 18 19 Iterative belief propagation decoder for low-density parity-check (LDPC) 20 codes and other `codes on graphs`. 21 22 This class defines a generic belief propagation decoder for decoding 23 with arbitrary parity-check matrices. It can be used to iteratively 24 estimate/recover the transmitted codeword (or information bits) based on the 25 LLR-values of the received noisy codeword observation. 26 27 The decoder implements the flooding SPA algorithm [Ryan]_, i.e., all nodes 28 are updated in a parallel fashion. Different check node update functions are 29 available 30 31 (1) `boxplus` 32 33 .. math:: 34 y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right) 35 36 (2) `boxplus-phi` 37 38 .. math:: 39 y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}_(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right) 40 41 with :math:`\phi(x)=-\operatorname{log}(\operatorname{tanh} \left(\frac{x}{2}) \right)` 42 43 (3) `minsum` 44 45 .. 
math:: 46 \qquad y_{j \to i} = \alpha_{j \to i} \cdot {min}_{i' \in \mathcal{N}_(j) \setminus i} \left(|x_{i' \to j}|\right) 47 48 where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to 49 variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*, 50 respectively. Further, :math:`\mathcal{N}_(j)` denotes all indices of 51 connected VNs to CN *j* and 52 53 .. math:: 54 \alpha_{j \to i} = \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{sign}(x_{i' \to j}) 55 56 is the sign of the outgoing message. For further details we refer to 57 [Ryan]_. 58 59 Note that for full 5G 3GPP NR compatibility, the correct puncturing and 60 shortening patterns must be applied (cf. [Richardson]_ for details), this 61 can be done by :class:`~sionna.fec.ldpc.decoding.LDPC5GEncoder` and 62 :class:`~sionna.fec.ldpc.decoding.LDPC5GDecoder`, respectively. 63 64 If required, the decoder can be made trainable and is fully differentiable 65 by following the concept of `weighted BP` [Nachmani]_ as shown in Fig. 1 66 leading to 67 68 .. math:: 69 y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{tanh} \left( \frac{\textcolor{red}{w_{i' \to j}} \cdot x_{i' \to j}}{2} \right) \right) 70 71 where :math:`w_{i \to j}` denotes the trainable weight of message :math:`x_{i \to j}`. 72 Please note that the training of some check node types may be not supported. 73 74 .. figure:: ../figures/weighted_bp.png 75 76 Fig. 1: Weighted BP as proposed in [Nachmani]_. 77 78 For numerical stability, the decoder applies LLR clipping of 79 +/- 20 to the input LLRs. 80 81 The class inherits from the Keras layer class and can be used as layer in a 82 Keras model. 83 84 Parameters 85 ---------- 86 pcm: ndarray 87 An ndarray of shape `[n-k, n]` defining the parity-check matrix 88 consisting only of `0` or `1` entries. Can be also of type `scipy. 89 sparse.csr_matrix` or `scipy.sparse.csc_matrix`. 
90 91 trainable: bool 92 Defaults to False. If True, every outgoing variable node message is 93 scaled with a trainable scalar. 94 95 cn_type: str 96 A string defaults to '"boxplus-phi"'. One of 97 {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where 98 '"boxplus"' implements the single-parity-check APP decoding rule. 99 '"boxplus-phi"' implements the numerical more stable version of 100 boxplus [Ryan]_. 101 '"minsum"' implements the min-approximation of the CN 102 update rule [Ryan]_. 103 104 hard_out: bool 105 Defaults to True. If True, the decoder provides hard-decided 106 codeword bits instead of soft-values. 107 108 track_exit: bool 109 Defaults to False. If True, the decoder tracks EXIT 110 characteristics. Note that this requires the all-zero 111 CW as input. 112 113 num_iter: int 114 Defining the number of decoder iteration (no early stopping used at 115 the moment!). 116 117 stateful: bool 118 Defaults to False. If True, the internal VN messages ``msg_vn`` 119 from the last decoding iteration are returned, and ``msg_vn`` or 120 `None` needs to be given as a second input when calling the decoder. 121 This is required for iterative demapping and decoding. 122 123 output_dtype: tf.DType 124 Defaults to tf.float32. Defines the output datatype of the layer 125 (internal precision remains tf.float32). 126 127 Input 128 ----- 129 llrs_ch or (llrs_ch, msg_vn): 130 Tensor or Tuple (only required if ``stateful`` is True): 131 132 llrs_ch: [...,n], tf.float32 133 2+D tensor containing the channel logits/llr values. 134 135 msg_vn: None or RaggedTensor, tf.float32 136 Ragged tensor of VN messages. 137 Required only if ``stateful`` is True. 138 139 Output 140 ------ 141 : [...,n], tf.float32 142 2+D Tensor of same shape as ``inputs`` containing 143 bit-wise soft-estimates (or hard-decided bit-values) of all 144 codeword bits. 145 146 : RaggedTensor, tf.float32: 147 Tensor of VN messages. 148 Returned only if ``stateful`` is set to True. 
149 150 Attributes 151 ---------- 152 pcm: ndarray 153 An ndarray of shape `[n-k, n]` defining the parity-check matrix 154 consisting only of `0` or `1` entries. Can be also of type `scipy. 155 sparse.csr_matrix` or `scipy.sparse.csc_matrix`. 156 157 num_cns: int 158 Defining the number of check nodes. 159 160 num_vns: int 161 Defining the number of variable nodes. 162 163 num_edges: int 164 Defining the total number of edges. 165 166 trainable: bool 167 If True, the decoder uses trainable weights. 168 169 _atanh_clip_value: float 170 Defining the internal clipping value before the atanh is applied 171 (relates to the CN update). 172 173 _cn_type: str 174 Defining the CN update function type. 175 176 _cn_update: 177 A function defining the CN update. 178 179 _hard_out: bool 180 If True, the decoder outputs hard-decided bits. 181 182 _cn_con: ndarray 183 An ndarray of shape `[num_edges]` defining all edges from check 184 node perspective. 185 186 _vn_con: ndarray 187 An ndarray of shape `[num_edges]` defining all edges from variable 188 node perspective. 189 190 _vn_mask_tf: tf.float32 191 A ragged Tensor of shape `[num_vns, None]` defining the incoming 192 message indices per VN. The second dimension is ragged and depends 193 on the node degree. 194 195 _cn_mask_tf: tf.float32 196 A ragged Tensor of shape `[num_cns, None]` defining the incoming 197 message indices per CN. The second dimension is ragged and depends 198 on the node degree. 199 200 _ind_cn: ndarray 201 An ndarray of shape `[num_edges]` defining the permutation index to 202 rearrange messages from variable into check node perspective. 203 204 _ind_cn_inv: ndarray 205 An ndarray of shape `[num_edges]` defining the permutation index to 206 rearrange messages from check into variable node perspective. 207 208 _vn_row_splits: ndarray 209 An ndarray of shape `[num_vns+1]` defining the row split positions 210 of a 1D vector consisting of all edges messages. 
Used to build a 211 ragged Tensor of incoming VN messages. 212 213 _cn_row_splits: ndarray 214 An ndarray of shape `[num_cns+1]` defining the row split positions 215 of a 1D vector consisting of all edges messages. Used to build a 216 ragged Tensor of incoming CN messages. 217 218 _edge_weights: tf.float32 219 A Tensor of shape `[num_edges]` defining a (trainable) weight per 220 outgoing VN message. 221 222 Raises: 223 ValueError 224 If the shape of ``pcm`` is invalid or contains other values than 225 `0` or `1` or dtype is not `tf.float32`. 226 227 ValueError 228 If ``num_iter`` is not an integer greater (or equal) `0`. 229 230 ValueError 231 If ``output_dtype`` is not 232 {tf.float16, tf.float32, tf.float64}. 233 234 ValueError 235 If ``inputs`` is not of shape `[batch_size, n]`. 236 237 InvalidArgumentError 238 When rank(``inputs``)<2. 239 Note 240 ---- 241 As decoding input logits 242 :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are 243 assumed for compatibility with the learning framework, but internally 244 log-likelihood ratios (LLRs) with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. 245 246 The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC 247 codes and, thus, supports arbitrary parity-check matrices. 248 249 The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to 250 account for arbitrary node degrees. To avoid a performance degradation 251 caused by a severe indexing overhead, the batch-dimension is shifted to 252 the last dimension during decoding. 253 254 If the decoder is made trainable [Nachmani]_, for performance 255 improvements only variable to check node messages are scaled as the VN 256 operation is linear and, thus, would not increase the expressive power 257 of the weights. 
258 259 """ 260 261 def __init__(self, 262 pcm, 263 trainable=False, 264 cn_type='boxplus-phi', 265 hard_out=True, 266 track_exit=False, 267 num_iter=20, 268 stateful=False, 269 output_dtype=tf.float32, 270 **kwargs): 271 272 super().__init__(dtype=output_dtype, **kwargs) 273 274 assert isinstance(trainable, bool), 'trainable must be bool.' 275 assert isinstance(hard_out, bool), 'hard_out must be bool.' 276 assert isinstance(track_exit, bool), 'track_exit must be bool.' 277 assert isinstance(cn_type, str) , 'cn_type must be str.' 278 assert isinstance(num_iter, int), 'num_iter must be int.' 279 assert num_iter>=0, 'num_iter cannot be negative.' 280 assert isinstance(stateful, bool), 'stateful must be bool.' 281 assert isinstance(output_dtype, tf.DType), \ 282 'output_dtype must be tf.Dtype.' 283 284 if isinstance(pcm, np.ndarray): 285 assert np.array_equal(pcm, pcm.astype(bool)), 'PC matrix \ 286 must be binary.' 287 elif isinstance(pcm, sp.sparse.csr_matrix): 288 assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 289 'PC matrix must be binary.' 290 elif isinstance(pcm, sp.sparse.csc_matrix): 291 assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 292 'PC matrix must be binary.' 
293 else: 294 raise TypeError("Unsupported dtype of pcm.") 295 296 if output_dtype not in (tf.float16, tf.float32, tf.float64): 297 raise ValueError( 298 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 299 300 if output_dtype is not tf.float32: 301 print('Note: decoder uses tf.float32 for internal calculations.') 302 303 # init decoder parameters 304 self._pcm = pcm 305 self._trainable = trainable 306 self._cn_type = cn_type 307 self._hard_out = hard_out 308 self._track_exit = track_exit 309 self._num_iter = tf.constant(num_iter, dtype=tf.int32) 310 self._stateful = stateful 311 self._output_dtype = output_dtype 312 313 # clipping value for the atanh function is applied (tf.float32 is used) 314 self._atanh_clip_value = 1 - 1e-7 315 # internal value for llr clipping 316 self._llr_max = tf.constant(20., tf.float32) 317 318 # init code parameters 319 self._num_cns = pcm.shape[0] # total number of check nodes 320 self._num_vns = pcm.shape[1] # total number of variable nodes 321 322 # make pcm sparse first if ndarray is provided 323 if isinstance(pcm, np.ndarray): 324 pcm = sp.sparse.csr_matrix(pcm) 325 326 # find all edges from variable and check node perspective 327 self._cn_con, self._vn_con, _ = sp.sparse.find(pcm) 328 329 # sort indices explicitly, as scipy.sparse.find changed from column to 330 # row sorting in scipy>=1.11 331 idx = np.argsort(self._vn_con) 332 self._cn_con = self._cn_con[idx] 333 self._vn_con = self._vn_con[idx] 334 335 # number of edges equals number of non-zero elements in the 336 # parity-check matrix 337 self._num_edges = len(self._vn_con) 338 339 # permutation index to rearrange messages into check node perspective 340 self._ind_cn = np.argsort(self._cn_con) 341 342 # inverse permutation index to rearrange messages back into variable 343 # node perspective 344 self._ind_cn_inv = np.argsort(self._ind_cn) 345 346 # generate row masks (array of integers defining the row split pos.) 
347 self._vn_row_splits = self._gen_node_mask_row(self._vn_con) 348 self._cn_row_splits = self._gen_node_mask_row( 349 self._cn_con[self._ind_cn]) 350 # pre-load the CN function for performance reasons 351 if self._cn_type=='boxplus': 352 # check node update using the tanh function 353 self._cn_update = self._cn_update_tanh 354 elif self._cn_type=='boxplus-phi': 355 # check node update using the "_phi" function 356 self._cn_update = self._cn_update_phi 357 elif self._cn_type=='minsum': 358 # check node update using the min-sum approximation 359 self._cn_update = self._cn_update_minsum 360 else: 361 raise ValueError('Unknown node type.') 362 363 # init trainable weights if needed 364 self._has_weights = False # indicates if trainable weights exist 365 if self._trainable: 366 self._has_weights = True 367 self._edge_weights = tf.Variable(tf.ones(self._num_edges), 368 trainable=self._trainable, 369 dtype=tf.float32) 370 371 # track mutual information during decoding 372 self._ie_c = 0 373 self._ie_v = 0 374 375 ######################################### 376 # Public methods and properties 377 ######################################### 378 379 @property 380 def pcm(self): 381 """Parity-check matrix of LDPC code.""" 382 return self._pcm 383 384 @property 385 def num_cns(self): 386 """Number of check nodes.""" 387 return self._num_cns 388 389 @property 390 def num_vns(self): 391 """Number of variable nodes.""" 392 return self._num_vns 393 394 @property 395 def num_edges(self): 396 """Number of edges in decoding graph.""" 397 return self._num_edges 398 399 @property 400 def has_weights(self): 401 """Indicates if decoder has trainable weights.""" 402 return self._has_weights 403 404 @property 405 def edge_weights(self): 406 """Trainable weights of the BP decoder.""" 407 if not self._has_weights: 408 return [] 409 else: 410 return self._edge_weights 411 412 @property 413 def output_dtype(self): 414 """Output dtype of decoder.""" 415 return self._output_dtype 416 417 @property 
418 def ie_c(self): 419 "Extrinsic mutual information at check node." 420 return self._ie_c 421 422 @property 423 def ie_v(self): 424 "Extrinsic mutual information at variable node." 425 return self._ie_v 426 427 @property 428 def num_iter(self): 429 "Number of decoding iterations." 430 return self._num_iter 431 432 @num_iter.setter 433 def num_iter(self, num_iter): 434 "Number of decoding iterations." 435 assert isinstance(num_iter, int), 'num_iter must be int.' 436 assert num_iter>=0, 'num_iter cannot be negative.' 437 self._num_iter = tf.constant(num_iter, dtype=tf.int32) 438 439 @property 440 def llr_max(self): 441 """Max LLR value used for internal calculations and rate-matching.""" 442 return self._llr_max 443 444 @llr_max.setter 445 def llr_max(self, value): 446 """Max LLR value used for internal calculations and rate-matching.""" 447 assert value>=0, 'llr_max cannot be negative.' 448 self._llr_max = tf.cast(value, dtype=tf.float32) 449 450 def show_weights(self, size=7): 451 """Show histogram of trainable weights. 452 453 Input 454 ----- 455 size: float 456 Figure size of the matplotlib figure. 457 458 """ 459 # only plot if weights exist 460 if self._has_weights: 461 weights = self._edge_weights.numpy() 462 463 plt.figure(figsize=(size,size)) 464 plt.hist(weights, density=True, bins=20, align='mid') 465 plt.xlabel('weight value') 466 plt.ylabel('density') 467 plt.grid(True, which='both', axis='both') 468 plt.title('Weight Distribution') 469 else: 470 print("No weights to show.") 471 472 ######################### 473 # Utility methods 474 ######################### 475 476 def _gen_node_mask(self, con): 477 """ Generates internal node masks indicating which msg index belongs 478 to which node index. 
479 """ 480 ind = np.argsort(con) 481 con = con[ind] 482 483 node_mask = [] 484 485 cur_node = 0 486 cur_mask = [] 487 for i in range(self._num_edges): 488 if con[i] == cur_node: 489 cur_mask.append(ind[i]) 490 else: 491 node_mask.append(cur_mask) 492 cur_mask = [ind[i]] 493 cur_node += 1 494 node_mask.append(cur_mask) 495 return node_mask 496 497 def _gen_node_mask_row(self, con): 498 """ Defining the row split positions of a 1D vector consisting of all 499 edges messages. 500 501 Used to build a ragged Tensor of incoming node messages. 502 """ 503 node_mask = [0] # the first element indicates the first node index (=0) 504 505 cur_node = 0 506 for i in range(self._num_edges): 507 if con[i] != cur_node: 508 node_mask.append(i) 509 cur_node += 1 510 node_mask.append(self._num_edges) # last element must be the number of 511 # elements (delimiter) 512 return node_mask 513 514 def _vn_update(self, msg, llr_ch): 515 """ Variable node update function. 516 517 This function implements the (extrinsic) variable node update 518 function. It takes the sum over all incoming messages ``msg`` excluding 519 the intrinsic (= outgoing) message itself. 520 521 Additionally, the channel LLR ``llr_ch`` is added to each message. 
522 """ 523 # aggregate all incoming messages per node 524 x = tf.reduce_sum(msg, axis=1) 525 x = tf.add(x, llr_ch) 526 527 # TF2.9 does not support XLA for the addition of ragged tensors 528 # the following code provides a workaround that supports XLA 529 530 # subtract extrinsic message from node value 531 # x = tf.expand_dims(x, axis=1) 532 # x = tf.add(-msg, x) 533 x = tf.ragged.map_flat_values(lambda x, y, row_ind : 534 x + tf.gather(y, row_ind), 535 -1.*msg, 536 x, 537 msg.value_rowids()) 538 return x 539 540 def _where_ragged(self, msg): 541 """Helper to replace 0 elements from ragged tensor (called with 542 map_flat_values).""" 543 return tf.where(tf.equal(msg, 0), tf.ones_like(msg) * 1e-12, msg) 544 545 def _where_ragged_inv(self, msg): 546 """Helper to replace small elements from ragged tensor (called with 547 map_flat_values) with exact `0`.""" 548 msg_mod = tf.where(tf.less(tf.abs(msg), 1e-7), 549 tf.zeros_like(msg), 550 msg) 551 return msg_mod 552 553 def _cn_update_tanh(self, msg): 554 """Check node update function implementing the exact boxplus operation. 555 556 This function implements the (extrinsic) check node update 557 function. It calculates the boxplus function over all incoming messages 558 "msg" excluding the intrinsic (=outgoing) message itself. 559 The exact boxplus function is implemented by using the tanh function. 560 561 The input is expected to be a ragged Tensor of shape 562 `[num_cns, None, batch_size]`. 563 564 Note that for numerical stability clipping is applied. 
565 """ 566 567 msg = msg / 2 568 # tanh is not overloaded for ragged tensors 569 msg = tf.ragged.map_flat_values(tf.tanh, msg) # tanh is not overloaded 570 571 # for ragged tensors; map to flat tensor first 572 msg = tf.ragged.map_flat_values(self._where_ragged, msg) 573 574 msg_prod = tf.reduce_prod(msg, axis=1) 575 576 # TF2.9 does not support XLA for the multiplication of ragged tensors 577 # the following code provides a workaround that supports XLA 578 579 # ^-1 to avoid division 580 # Note this is (potentially) numerically unstable 581 # msg = msg**-1 * tf.expand_dims(msg_prod, axis=1) # remove own edge 582 583 msg = tf.ragged.map_flat_values(lambda x, y, row_ind : 584 x * tf.gather(y, row_ind), 585 msg**-1, 586 msg_prod, 587 msg.value_rowids()) 588 589 # Overwrite small (numerical zeros) message values with exact zero 590 # these are introduced by the previous "_where_ragged" operation 591 # this is required to keep the product stable (cf. _phi_update for log 592 # sum implementation) 593 msg = tf.ragged.map_flat_values(self._where_ragged_inv, msg) 594 595 msg = tf.clip_by_value(msg, 596 clip_value_min=-self._atanh_clip_value, 597 clip_value_max=self._atanh_clip_value) 598 599 # atanh is not overloaded for ragged tensors 600 msg = 2 * tf.ragged.map_flat_values(tf.atanh, msg) 601 return msg 602 603 def _phi(self, x): 604 """Helper function for the check node update. 605 606 This function implements the (element-wise) `"_phi"` function as defined 607 in [Ryan]_. 608 """ 609 # the clipping values are optimized for tf.float32 610 x = tf.clip_by_value(x, clip_value_min=8.5e-8, clip_value_max=16.635532) 611 return tf.math.log(tf.math.exp(x)+1) - tf.math.log(tf.math.exp(x)-1) 612 613 def _cn_update_phi(self, msg): 614 """Check node update function implementing the exact boxplus operation. 615 616 This function implements the (extrinsic) check node update function 617 based on the numerically more stable `"_phi"` function (cf. [Ryan]_). 
618 It calculates the boxplus function over all incoming messages ``msg`` 619 excluding the intrinsic (=outgoing) message itself. 620 The exact boxplus function is implemented by using the `"_phi"` function 621 as in [Ryan]_. 622 623 The input is expected to be a ragged Tensor of shape 624 `[num_cns, None, batch_size]`. 625 626 Note that for numerical stability clipping is applied. 627 """ 628 629 sign_val = tf.sign(msg) 630 631 # TF2.14 does not support XLA for tf.where and ragged tensors in 632 # CPU mode. The following code provides a workaround that supports XLA 633 # sign_val = tf.where(tf.equal(sign_val, 0), 634 # tf.ones_like(sign_val), 635 # sign_val) 636 sign_val = tf.ragged.map_flat_values(lambda x : 637 tf.where(tf.equal(x, 0), 638 tf.ones_like(x),x), 639 sign_val) 640 641 sign_node = tf.reduce_prod(sign_val, axis=1) 642 643 # TF2.9 does not support XLA for the multiplication of ragged tensors 644 # the following code provides a workaround that supports XLA 645 646 # sign_val = sign_val * tf.expand_dims(sign_node, axis=1) 647 sign_val = tf.ragged.map_flat_values(lambda x, y, row_ind : 648 x * tf.gather(y, row_ind), 649 sign_val, 650 sign_node, 651 sign_val.value_rowids()) 652 653 msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign 654 655 # apply _phi element-wise (does not support ragged Tensors) 656 msg = tf.ragged.map_flat_values(self._phi, msg) 657 msg_sum = tf.reduce_sum(msg, axis=1) 658 659 # TF2.9 does not support XLA for the addition of ragged tensors 660 # the following code provides a workaround that supports XLA 661 662 # msg = tf.add( -msg, tf.expand_dims(msg_sum, axis=1)) # remove own edge 663 msg = tf.ragged.map_flat_values(lambda x, y, row_ind : 664 x + tf.gather(y, row_ind), 665 -1.*msg, 666 msg_sum, 667 msg.value_rowids()) 668 669 # apply _phi element-wise (does not support ragged Tensors) 670 msg = self._stop_ragged_gradient(sign_val) * tf.ragged.map_flat_values( 671 self._phi, msg) 672 return msg 673 674 def 
_stop_ragged_gradient(self, rt): 675 """Helper function as TF 2.5 does not support ragged gradient 676 stopping""" 677 return rt.with_flat_values(tf.stop_gradient(rt.flat_values)) 678 679 def _sign_val_minsum(self, msg): 680 """Helper to replace find sign-value during min-sum decoding. 681 Must be called with `map_flat_values`.""" 682 683 sign_val = tf.sign(msg) 684 sign_val = tf.where(tf.equal(sign_val, 0), 685 tf.ones_like(sign_val), 686 sign_val) 687 return sign_val 688 689 def _cn_update_minsum(self, msg): 690 """ Check node update function implementing the min-sum approximation. 691 692 This function approximates the (extrinsic) check node update 693 function based on the min-sum approximation (cf. [Ryan]_). 694 It calculates the "extrinsic" min function over all incoming messages 695 ``msg`` excluding the intrinsic (=outgoing) message itself. 696 697 The input is expected to be a ragged Tensor of shape 698 `[num_vns, None, batch_size]`. 699 """ 700 701 # a constant used to overwrite the first min 702 LARGE_VAL = 10000. 
# pylint: disable=invalid-name 703 704 # clip values for numerical stability 705 msg = tf.clip_by_value(msg, 706 clip_value_min=-self._llr_max, 707 clip_value_max=self._llr_max) 708 709 # calculate sign of outgoing msg and the node 710 sign_val = tf.ragged.map_flat_values(self._sign_val_minsum, msg) 711 sign_node = tf.reduce_prod(sign_val, axis=1) 712 713 # TF2.9 does not support XLA for the multiplication of ragged tensors 714 # the following code provides a workaround that supports XLA 715 716 # sign_val = self._stop_ragged_gradient(sign_val) \ 717 # * tf.expand_dims(sign_node, axis=1) 718 sign_val = tf.ragged.map_flat_values( 719 lambda x, y, row_ind: 720 tf.multiply(x, tf.gather(y, row_ind)), 721 self._stop_ragged_gradient(sign_val), 722 sign_node, 723 sign_val.value_rowids()) 724 725 # remove sign from messages 726 msg = tf.ragged.map_flat_values(tf.abs, msg) 727 728 # Calculate the extrinsic minimum per CN, i.e., for each message of 729 # index i, find the smallest and the second smallest value. 730 # However, in some cases the second smallest value may equal the 731 # smallest value (multiplicity of mins). 732 # Please note that this needs to be applied to raggedTensors, e.g., 733 # tf.top_k() is currently not supported and all ops must support graph 734 # and XLA mode. 
735 736 # find min_value per node 737 min_val = tf.reduce_min(msg, axis=1, keepdims=True) 738 739 # TF2.9 does not support XLA for the subtraction of ragged tensors 740 # the following code provides a workaround that supports XLA 741 742 # and subtract min; the new array contains zero at the min positions 743 # benefits from broadcasting; all other values are positive 744 msg_min1 = tf.ragged.map_flat_values(lambda x, y, row_ind: 745 x - tf.gather(y, row_ind), 746 msg, 747 tf.squeeze(min_val, axis=1), 748 msg.value_rowids()) 749 750 # replace 0 (=min positions) with large value to ignore it for further 751 # min calculations 752 msg = tf.ragged.map_flat_values( 753 lambda x: tf.where(tf.equal(x, 0), LARGE_VAL, x), 754 msg_min1) 755 756 # find the second smallest element (we add min_val as this has been 757 # subtracted before) 758 min_val_2 = tf.reduce_min(msg, axis=1, keepdims=True) + min_val 759 760 # Detect duplicated minima (i.e., min_val occurs at two incoming 761 # messages). As the LLRs per node are <LLR_MAX and we have 762 # replace at least 1 position (position with message "min_val") by 763 # LARGE_VAL, it holds for the sum < LARGE_VAL + node_degree*LLR_MAX. 764 # If the sum > 2*LARGE_VAL, the multiplicity of the min is at least 2. 765 node_sum = tf.reduce_sum(msg, axis=1, keepdims=True) - (2*LARGE_VAL-1.) 766 # indicator that duplicated min was detected (per node) 767 double_min = 0.5*(1-tf.sign(node_sum)) 768 769 # if a duplicate min occurred, both edges must have min_val, otherwise 770 # the second smallest value is taken 771 min_val_e = (1-double_min) * min_val + (double_min) * min_val_2 772 773 # replace all values with min_val except the position where the min 774 # occurred (=extrinsic min). 
775 776 # no XLA support for TF 2.15 777 # msg_e = tf.where(msg==LARGE_VAL, min_val_e, min_val) 778 779 min_1 = tf.squeeze(tf.gather(min_val, msg.value_rowids()), axis=1) 780 min_e = tf.squeeze(tf.gather(min_val_e, msg.value_rowids()), axis=1) 781 msg_e = tf.ragged.map_flat_values( 782 lambda x: tf.where(x==LARGE_VAL, min_e, min_1), 783 msg) 784 785 # it seems like tf.where does not set the shape of tf.ragged properly 786 # we need to ensure the shape manually 787 msg_e = tf.ragged.map_flat_values( 788 lambda x: tf.ensure_shape(x, msg.flat_values.shape), 789 msg_e) 790 791 # TF2.9 does not support XLA for the multiplication of ragged tensors 792 # the following code provides a workaround that supports XLA 793 794 # and apply sign 795 #msg = sign_val * msg_e 796 msg = tf.ragged.map_flat_values(tf.multiply, 797 sign_val, 798 msg_e) 799 800 return msg 801 802 def _mult_weights(self, x): 803 """Multiply messages with trainable weights for weighted BP.""" 804 # transpose for simpler broadcasting of training variables 805 x = tf.transpose(x, (1, 0)) 806 x = tf.math.multiply(x, self._edge_weights) 807 x = tf.transpose(x, (1, 0)) 808 return x 809 810 ######################### 811 # Keras layer functions 812 ######################### 813 814 def build(self, input_shape): 815 # Raise AssertionError if shape of x is invalid 816 if self._stateful: 817 assert(len(input_shape)==2), \ 818 "For stateful decoding, a tuple of two inputs is expected." 819 input_shape = input_shape[0] 820 821 assert (input_shape[-1]==self._num_vns), \ 822 'Last dimension must be of length n.' 823 assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' 824 825 def call(self, inputs): 826 """Iterative BP decoding function. 827 828 This function performs ``num_iter`` belief propagation decoding 829 iterations and returns the estimated codeword. 830 831 Args: 832 llr_ch or (llr_ch, msg_vn): 833 834 llr_ch (tf.float32): Tensor of shape `[...,n]` containing the 835 channel logits/llr values. 
836 837 msg_vn (tf.float32) : Ragged tensor containing the VN 838 messages, or None. Required if ``stateful`` is set to True. 839 840 Returns: 841 `tf.float32`: Tensor of shape `[...,n]` containing 842 bit-wise soft-estimates (or hard-decided bit-values) of all 843 codeword bits. 844 845 Raises: 846 ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 847 848 InvalidArgumentError: When rank(``inputs``)<2. 849 """ 850 851 # Extract inputs 852 if self._stateful: 853 llr_ch, msg_vn = inputs 854 else: 855 llr_ch = inputs 856 857 tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 858 859 # internal calculations still in tf.float32 860 llr_ch = tf.cast(llr_ch, tf.float32) 861 862 # clip llrs for numerical stability 863 llr_ch = tf.clip_by_value(llr_ch, 864 clip_value_min=-self._llr_max, 865 clip_value_max=self._llr_max) 866 867 # last dim must be of length n 868 tf.debugging.assert_equal(tf.shape(llr_ch)[-1], 869 self._num_vns, 870 'Last dimension must be of length n.') 871 872 llr_ch_shape = llr_ch.get_shape().as_list() 873 new_shape = [-1, self._num_vns] 874 llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 875 876 # must be done during call, as XLA fails otherwise due to ragged 877 # indices placed on the CPU device. 878 # create permutation index from cn perspective 879 self._cn_mask_tf = tf.ragged.constant(self._gen_node_mask(self._cn_con), 880 row_splits_dtype=tf.int32) 881 882 # batch dimension is last dimension due to ragged tensor representation 883 llr_ch = tf.transpose(llr_ch_reshaped, (1,0)) 884 885 llr_ch = -1. * llr_ch # logits are converted into "true" llrs 886 887 # init internal decoder state if not explicitly 888 # provided (e.g., required to restore decoder state for iterative 889 # detection and decoding) 890 # load internal state from previous iteration 891 # required for iterative det./dec. 
892 if not self._stateful or msg_vn is None: 893 msg_shape = tf.stack([tf.constant(self._num_edges), 894 tf.shape(llr_ch)[1]], 895 axis=0) 896 msg_vn = tf.zeros(msg_shape, dtype=tf.float32) 897 else: 898 msg_vn = msg_vn.flat_values 899 900 # track exit decoding trajectory; requires all-zero cw? 901 if self._track_exit: 902 self._ie_c = tf.zeros(self._num_iter+1) 903 self._ie_v = tf.zeros(self._num_iter+1) 904 905 # perform one decoding iteration 906 # Remark: msg_vn cannot be ragged as input for tf.while_loop as 907 # otherwise XLA will not be supported (with TF 2.5) 908 def dec_iter(llr_ch, msg_vn, it): 909 it += 1 910 911 msg_vn = tf.RaggedTensor.from_row_splits( 912 values=msg_vn, 913 row_splits=tf.constant(self._vn_row_splits, tf.int32)) 914 # variable node update 915 msg_vn = self._vn_update(msg_vn, llr_ch) 916 917 # track exit decoding trajectory; requires all-zero cw 918 if self._track_exit: 919 # neg values as different llr def is expected 920 mi = llr2mi(-1. * msg_vn.flat_values) 921 self._ie_v = tf.tensor_scatter_nd_add(self._ie_v, 922 tf.reshape(it, (1, 1)), 923 tf.reshape(mi, (1))) 924 925 # scale outgoing vn messages (weighted BP); only if activated 926 if self._has_weights: 927 msg_vn = tf.ragged.map_flat_values(self._mult_weights, 928 msg_vn) 929 # permute edges into CN perspective 930 msg_cn = tf.gather(msg_vn.flat_values, self._cn_mask_tf, axis=None) 931 932 # check node update using the pre-defined function 933 msg_cn = self._cn_update(msg_cn) 934 935 # track exit decoding trajectory; requires all-zero cw? 
936 if self._track_exit: 937 # neg values as different llr def is expected 938 mi = llr2mi(-1.*msg_cn.flat_values) 939 # update pos i+1 such that first iter is stored as 0 940 self._ie_c = tf.tensor_scatter_nd_add(self._ie_c, 941 tf.reshape(it, (1, 1)), 942 tf.reshape(mi, (1))) 943 944 # re-permute edges to variable node perspective 945 msg_vn = tf.gather(msg_cn.flat_values, self._ind_cn_inv, axis=None) 946 return llr_ch, msg_vn, it 947 948 # stopping condition (required for tf.while_loop) 949 def dec_stop(llr_ch, msg_vn, it): # pylint: disable=W0613 950 return tf.less(it, self._num_iter) 951 952 # start decoding iterations 953 it = tf.constant(0) 954 # maximum_iterations required for XLA 955 _, msg_vn, _ = tf.while_loop(dec_stop, 956 dec_iter, 957 (llr_ch, msg_vn, it), 958 parallel_iterations=1, 959 maximum_iterations=self._num_iter) 960 961 962 # raggedTensor for final marginalization 963 msg_vn = tf.RaggedTensor.from_row_splits( 964 values=msg_vn, 965 row_splits=tf.constant(self._vn_row_splits, tf.int32)) 966 967 # marginalize and remove ragged Tensor 968 x_hat = tf.add(llr_ch, tf.reduce_sum(msg_vn, axis=1)) 969 970 # restore batch dimension to first dimension 971 x_hat = tf.transpose(x_hat, (1,0)) 972 973 x_hat = -1. 
* x_hat # convert llrs back into logits 974 975 if self._hard_out: # hard decide decoder output if required 976 x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype) 977 978 # Reshape c_short so that it matches the original input dimensions 979 output_shape = llr_ch_shape 980 output_shape[0] = -1 # overwrite batch dim (can be None in Keras) 981 982 x_reshaped = tf.reshape(x_hat, output_shape) 983 984 # cast output to output_dtype 985 x_out = tf.cast(x_reshaped, self._output_dtype) 986 987 if not self._stateful: 988 return x_out 989 else: 990 return x_out, msg_vn 991 992 class LDPC5GDecoder(LDPCBPDecoder): 993 # pylint: disable=line-too-long 994 r"""LDPC5GDecoder(encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs) 995 996 (Iterative) belief propagation decoder for 5G NR LDPC codes. 997 998 Inherits from :class:`~sionna.fec.ldpc.decoding.LDPCBPDecoder` and provides 999 a wrapper for 5G compatibility, i.e., automatically handles puncturing and 1000 shortening according to [3GPPTS38212_LDPC]_. 1001 1002 Note that for full 5G 3GPP NR compatibility, the correct puncturing and 1003 shortening patterns must be applied and, thus, the encoder object is 1004 required as input. 1005 1006 If required the decoder can be made trainable and is differentiable 1007 (the training of some check node types may be not supported) following the 1008 concept of "weighted BP" [Nachmani]_. 1009 1010 For numerical stability, the decoder applies LLR clipping of 1011 +/- 20 to the input LLRs. 1012 1013 The class inherits from the Keras layer class and can be used as layer in a 1014 Keras model. 1015 1016 Parameters 1017 ---------- 1018 encoder: LDPC5GEncoder 1019 An instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder` 1020 containing the correct code parameters. 1021 1022 trainable: bool 1023 Defaults to False. 
If True, every outgoing variable node message is
        scaled with a trainable scalar.

    cn_type: str
        Defaults to `"boxplus-phi"`. One of
        {`"boxplus"`, `"boxplus-phi"`, `"minsum"`}, where
        `"boxplus"` implements the single-parity-check APP decoding rule,
        `"boxplus-phi"` implements the numerically more stable version of
        boxplus [Ryan]_, and
        `"minsum"` implements the min-approximation of the CN
        update rule [Ryan]_.

    hard_out: bool
        Defaults to True. If True, the decoder provides hard-decided
        codeword bits instead of soft-values.

    track_exit: bool
        Defaults to False. If True, the decoder tracks EXIT characteristics.
        Note that this requires the all-zero codeword (CW) as input.

    return_infobits: bool
        Defaults to True. If True, only the `k` info bits (soft or
        hard-decided) are returned. Otherwise all `n` positions are
        returned.

    prune_pcm: bool
        Defaults to True. If True, all punctured degree-1 VNs and
        connected check nodes are removed from the decoding graph (see
        [Cammerer]_ for details). Apart from numerical differences, this
        should yield the same decoding result while improving the decoding
        throughput and reducing the memory footprint.

    num_iter: int
        Defaults to 20. Defines the number of decoder iterations (no early
        stopping is used at the moment!).

    stateful: bool
        Defaults to False. If True, the internal VN messages ``msg_vn``
        from the last decoding iteration are returned, and ``msg_vn`` or
        `None` needs to be given as a second input when calling the decoder.
        This is required for iterative demapping and decoding.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer
        (internal precision remains tf.float32).

    Input
    -----
    llrs_ch or (llrs_ch, msg_vn):
        Tensor, or tuple (the tuple is only required if ``stateful`` is
        True):

    llrs_ch: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    msg_vn: None or RaggedTensor, tf.float32
        Ragged tensor of VN messages.
        Required only if ``stateful`` is True.

    Output
    ------
    : [...,n] or [...,k], tf.float32
        2+D tensor with the same leading dimensions as ``llrs_ch``
        containing bit-wise soft-estimates (or hard-decided bit-values) of
        all codeword bits. If ``return_infobits`` is True, only the `k`
        information bits are returned.

    : RaggedTensor, tf.float32:
        Tensor of VN messages.
        Returned only if ``stateful`` is set to True.

    Raises
    ------
    ValueError
        If the parity-check matrix ``pcm`` of ``encoder`` has an invalid
        shape or contains values other than `0` or `1`.

    AssertionError
        If ``trainable`` is not `bool`.

    AssertionError
        If ``track_exit`` is not `bool`.

    AssertionError
        If ``hard_out`` is not `bool`.

    AssertionError
        If ``return_infobits`` is not `bool`.

    AssertionError
        If ``prune_pcm`` or ``stateful`` is not `bool`.

    AssertionError
        If ``encoder`` is not an instance of
        :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`.

    ValueError
        If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.

    ValueError
        If ``inputs`` is not of shape `[batch_size, n]`.

    ValueError
        If ``num_iter`` is not an integer greater than or equal to `0`.

    InvalidArgumentError
        When rank(``inputs``)<2.

    Note
    ----
    As decoding input, logits
    :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for
    compatibility with the learning framework, but
    internally llrs with definition
    :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.

    The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC
    codes and, thus, supports arbitrary parity-check matrices.
1137 1138 The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to 1139 account for arbitrary node degrees. To avoid a performance degradation 1140 caused by a severe indexing overhead, the batch-dimension is shifted to 1141 the last dimension during decoding. 1142 1143 If the decoder is made trainable [Nachmani]_, for performance 1144 improvements only variable to check node messages are scaled as the VN 1145 operation is linear and, thus, would not increase the expressive power 1146 of the weights. 1147 """ 1148 1149 def __init__(self, 1150 encoder, 1151 trainable=False, 1152 cn_type='boxplus-phi', 1153 hard_out=True, 1154 track_exit=False, 1155 return_infobits=True, 1156 prune_pcm=True, 1157 num_iter=20, 1158 stateful=False, 1159 output_dtype=tf.float32, 1160 **kwargs): 1161 1162 # needs the 5G Encoder to access all 5G parameters 1163 assert isinstance(encoder, LDPC5GEncoder), 'encoder must \ 1164 be of class LDPC5GEncoder.' 1165 self._encoder = encoder 1166 pcm = encoder.pcm 1167 1168 assert isinstance(return_infobits, bool), 'return_info must be bool.' 1169 self._return_infobits = return_infobits 1170 1171 assert isinstance(output_dtype, tf.DType), \ 1172 'output_dtype must be tf.DType.' 1173 if output_dtype not in (tf.float16, tf.float32, tf.float64): 1174 raise ValueError( 1175 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 1176 self._output_dtype = output_dtype 1177 1178 assert isinstance(stateful, bool), 'stateful must be bool.' 1179 self._stateful = stateful 1180 1181 assert isinstance(prune_pcm, bool), 'prune_pcm must be bool.' 1182 # prune punctured degree-1 VNs and connected CNs. A punctured 1183 # VN-1 node will always "send" llr=0 to the connected CN. Thus, this 1184 # CN will only send 0 messages to all other VNs, i.e., does not 1185 # contribute to the decoding process. 
1186 self._prune_pcm = prune_pcm 1187 if prune_pcm: 1188 # find index of first position with only degree-1 VN 1189 dv = np.sum(pcm, axis=0) # VN degree 1190 last_pos = encoder._n_ldpc 1191 for idx in range(encoder._n_ldpc-1, 0, -1): 1192 if dv[0, idx]==1: 1193 last_pos = idx 1194 else: 1195 break 1196 # number of filler bits 1197 k_filler = self.encoder.k_ldpc - self.encoder.k 1198 # number of punctured bits 1199 nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 1200 - self.encoder.n - 2*self.encoder.z) 1201 # effective codeword length after pruning of vn-1 nodes 1202 self._n_pruned = np.max((last_pos, encoder._n_ldpc - nb_punc_bits)) 1203 self._nb_pruned_nodes = encoder._n_ldpc - self._n_pruned 1204 # remove last CNs and VNs from pcm 1205 pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes] 1206 1207 #check for consistency 1208 assert(self._nb_pruned_nodes>=0), "Internal error: number of \ 1209 pruned nodes must be positive." 1210 else: 1211 self._nb_pruned_nodes = 0 1212 # no pruning; same length as before 1213 self._n_pruned = encoder._n_ldpc 1214 1215 super().__init__(pcm, 1216 trainable, 1217 cn_type, 1218 hard_out, 1219 track_exit, 1220 num_iter=num_iter, 1221 stateful=stateful, 1222 output_dtype=output_dtype, 1223 **kwargs) 1224 1225 ######################################### 1226 # Public methods and properties 1227 ######################################### 1228 1229 @property 1230 def encoder(self): 1231 """LDPC Encoder used for rate-matching/recovery.""" 1232 return self._encoder 1233 1234 ######################### 1235 # Keras layer functions 1236 ######################### 1237 1238 def build(self, input_shape): 1239 """Build model.""" 1240 if self._stateful: 1241 assert(len(input_shape)==2), \ 1242 "For stateful decoding, a tuple of two inputs is expected." 1243 input_shape = input_shape[0] 1244 1245 # check input dimensions for consistency 1246 assert (input_shape[-1]==self.encoder.n), \ 1247 'Last dimension must be of length n.' 
1248 assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' 1249 1250 self._old_shape_5g = input_shape 1251 1252 def call(self, inputs): 1253 """Iterative BP decoding function. 1254 1255 This function performs ``num_iter`` belief propagation decoding 1256 iterations and returns the estimated codeword. 1257 1258 Args: 1259 inputs (tf.float32): Tensor of shape `[...,n]` containing the 1260 channel logits/llr values. 1261 1262 Returns: 1263 `tf.float32`: Tensor of shape `[...,n]` or `[...,k]` 1264 (``return_infobits`` is True) containing bit-wise soft-estimates 1265 (or hard-decided bit-values) of all codeword bits (or info 1266 bits, respectively). 1267 1268 Raises: 1269 ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 1270 1271 ValueError: If ``num_iter`` is not an integer greater (or equal) 1272 `0`. 1273 1274 InvalidArgumentError: When rank(``inputs``)<2. 1275 """ 1276 1277 # Extract inputs 1278 if self._stateful: 1279 llr_ch, msg_vn = inputs 1280 else: 1281 llr_ch = inputs 1282 1283 tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 1284 1285 llr_ch_shape = llr_ch.get_shape().as_list() 1286 new_shape = [-1, llr_ch_shape[-1]] 1287 llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 1288 batch_size = tf.shape(llr_ch_reshaped)[0] 1289 1290 # invert if rate-matching output interleaver was applied as defined in 1291 # Sec. 
5.4.2.2 in 38.212 1292 if self._encoder.num_bits_per_symbol is not None: 1293 llr_ch_reshaped = tf.gather(llr_ch_reshaped, 1294 self._encoder.out_int_inv, 1295 axis=-1) 1296 1297 1298 # undo puncturing of the first 2*Z bit positions 1299 llr_5g = tf.concat( 1300 [tf.zeros([batch_size, 2*self.encoder.z], self._output_dtype), 1301 llr_ch_reshaped], 1302 1) 1303 1304 # undo puncturing of the last positions 1305 # total length must be n_ldpc, while llr_ch has length n 1306 # first 2*z positions are already added 1307 # -> add n_ldpc - n - 2Z punctured positions 1308 k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits 1309 nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 1310 - self.encoder.n - 2*self.encoder.z) 1311 1312 1313 llr_5g = tf.concat([llr_5g, 1314 tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes], 1315 self._output_dtype)], 1316 1) 1317 1318 # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max) 1319 # the first k positions are the systematic bits 1320 x1 = tf.slice(llr_5g, [0,0], [batch_size, self.encoder.k]) 1321 1322 # parity part 1323 nb_par_bits = (self.encoder.n_ldpc - k_filler 1324 - self.encoder.k - self._nb_pruned_nodes) 1325 x2 = tf.slice(llr_5g, 1326 [0, self.encoder.k], 1327 [batch_size, nb_par_bits]) 1328 1329 # negative sign due to logit definition 1330 z = -tf.cast(self._llr_max, self._output_dtype) \ 1331 * tf.ones([batch_size, k_filler], self._output_dtype) 1332 1333 llr_5g = tf.concat([x1, z, x2], 1) 1334 1335 # and execute the decoder 1336 if not self._stateful: 1337 x_hat = super().call(llr_5g) 1338 else: 1339 x_hat,msg_vn = super().call([llr_5g, msg_vn]) # pylint: disable=used-before-assignment 1340 1341 if self._return_infobits: # return only info bits 1342 # reconstruct u_hat # code is systematic 1343 u_hat = tf.slice(x_hat, [0,0], [batch_size, self.encoder.k]) 1344 # Reshape u_hat so that it matches the original input dimensions 1345 output_shape = llr_ch_shape[0:-1] + [self.encoder.k] 1346 
# overwrite first dimension as this could be None (Keras) 1347 output_shape[0] = -1 1348 u_reshaped = tf.reshape(u_hat, output_shape) 1349 1350 # enable other output datatypes than tf.float32 1351 u_out = tf.cast(u_reshaped, self._output_dtype) 1352 1353 if not self._stateful: 1354 return u_out 1355 else: 1356 return u_out, msg_vn 1357 1358 else: # return all codeword bits 1359 # the transmitted CW bits are not the same as used during decoding 1360 # cf. last parts of 5G encoding function 1361 1362 # remove last dim 1363 x = tf.reshape(x_hat, [batch_size, self._n_pruned]) 1364 1365 # remove filler bits at pos (k, k_ldpc) 1366 x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k]) 1367 1368 x_no_filler2 = tf.slice(x, 1369 [0, self.encoder.k_ldpc], 1370 [batch_size, 1371 self._n_pruned-self.encoder.k_ldpc]) 1372 1373 x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1) 1374 1375 # shorten the first 2*Z positions and end after n bits 1376 x_short = tf.slice(x_no_filler, 1377 [0, 2*self.encoder.z], 1378 [batch_size, self.encoder.n]) 1379 1380 # if used, apply rate-matching output interleaver again as 1381 # Sec. 5.4.2.2 in 38.212 1382 if self._encoder.num_bits_per_symbol is not None: 1383 x_short = tf.gather(x_short, self._encoder.out_int, axis=-1) 1384 1385 # Reshape x_short so that it matches the original input dimensions 1386 # overwrite first dimension as this could be None (Keras) 1387 llr_ch_shape[0] = -1 1388 x_short= tf.reshape(x_short, llr_ch_shape) 1389 1390 # enable other output datatypes than tf.float32 1391 x_out = tf.cast(x_short, self._output_dtype) 1392 1393 if not self._stateful: 1394 return x_out 1395 else: 1396 return x_out, msg_vn