encoding.py (26847B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for LDPC channel encoding and utility functions."""

import tensorflow as tf
import numpy as np
import scipy as sp
from tensorflow.keras.layers import Layer
from importlib_resources import files, as_file
from . import codes # pylint: disable=relative-beyond-top-level
import numbers # to check if n, k are numbers

from sionna.fec.linear import AllZeroEncoder as AllZeroEncoder_new

class LDPC5GEncoder(Layer):
    # pylint: disable=line-too-long
    """LDPC5GEncoder(k, n, num_bits_per_symbol=None, dtype=tf.float32, **kwargs)

    5G NR LDPC Encoder following the 3GPP NR Initiative [3GPPTS38212_LDPC]_
    including rate-matching.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        k: int
            Defining the number of information bits per codeword.

        n: int
            Defining the desired codeword length.

        num_bits_per_symbol: int or None
            Defining the number of bits per QAM symbol. If this parameter is
            explicitly provided, the codeword will be interleaved after
            rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.

        dtype: tf.DType
            Defaults to `tf.float32`. Defines the output datatype of the layer.
            Internally, the parity sums are computed in the layer dtype and
            reduced modulo 2 via `tf.uint8`.

    Input
    -----
        inputs: [...,k], tf.float32
            2+D tensor containing the information bits to be
            encoded.

    Output
    ------
        : [...,n], tf.float32
            2+D tensor of same shape as inputs besides last dimension has
            changed to `n` containing the encoded codeword bits.

    Attributes
    ----------
        k: int
            Defining the number of information bits per codeword.

        n: int
            Defining the desired codeword length.

        coderate: float
            Defining the coderate r= ``k`` / ``n``.

        n_ldpc: int
            An integer defining the total codeword length (before
            puncturing) of the lifted parity-check matrix.

        k_ldpc: int
            An integer defining the total information bit length
            (before zero removal) of the lifted parity-check matrix. Gap to
            ``k`` must be filled with so-called filler bits.

        num_bits_per_symbol: int or None.
            Defining the number of bits per QAM symbol. If this parameter is
            explicitly provided, the codeword will be interleaved after
            rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.

        out_int: [n], ndarray of int or None
            Defining the rate-matching output interleaver sequence.
            `None` if ``num_bits_per_symbol`` is `None`.

        out_int_inv: [n], ndarray of int or None
            Defining the inverse rate-matching output interleaver sequence.
            `None` if ``num_bits_per_symbol`` is `None`.

        _check_input: bool
            A boolean that indicates whether the input vector
            during call of the layer should be checked for consistency (i.e.,
            binary).

        _bg: str
            Denoting the selected basegraph (either `bg1` or `bg2`).

        _z: int
            Denoting the lifting factor.

        _i_ls: int
            Defining which version of the basegraph to load.
            Can take values between 0 and 7.

        _k_b: int
            Defining the number of `information bit columns` in the
            basegraph. Determined by the code design procedure in
            [3GPPTS38212_LDPC]_.

        _bm: ndarray
            An ndarray defining the basegraph.

        _pcm: sp.sparse.csr_matrix
            A sparse matrix of shape `[n_ldpc-k_ldpc, n_ldpc]`
            containing the sparse parity-check matrix.

    Raises
    ------
        AssertionError
            If ``k`` is not a number.

        AssertionError
            If ``n`` is not a number.

        ValueError
            If ``code_length`` is not supported (including ``n`` <= 0).

        ValueError
            If `dtype` is not supported.

        ValueError
            If ``inputs`` contains other values than `0` or `1`.

        InvalidArgumentError
            When rank(``inputs``)<2.

        InvalidArgumentError
            When shape of last dim is not ``k``.

    Note
    ----
        As specified in [3GPPTS38212_LDPC]_, the encoder also performs
        puncturing and shortening. Thus, the corresponding decoder needs to
        `invert` these operations, i.e., must be compatible with the 5G
        encoding scheme.
    """

    def __init__(self,
                 k,
                 n,
                 num_bits_per_symbol=None,
                 dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(k, numbers.Number), "k must be a number."
        assert isinstance(n, numbers.Number), "n must be a number."
        k = int(k) # k or n can be float (e.g. as result of n=k*r)
        n = int(n) # k or n can be float (e.g. as result of n=k*r)

        if dtype is not tf.float32:
            print("Note: decoder uses tf.float32 for internal calculations.")

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
                         tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise ValueError("Unsupported dtype.")
        self._dtype = dtype

        if k>8448:
            raise ValueError("Unsupported code length (k too large).")
        if k<12:
            raise ValueError("Unsupported code length (k too small).")

        if n>(316*384):
            raise ValueError("Unsupported code length (n too large).")
        # n=0 must be rejected here; otherwise the coderate k/n below would
        # raise a ZeroDivisionError instead of a meaningful ValueError
        if n<=0:
            raise ValueError("Unsupported code length (n must be positive).")

        # init encoder parameters
        self._k = k # number of input bits (= input shape)
        self._n = n # the desired length (= output shape)
        self._coderate = k / n
        self._check_input = True # check input for consistency (i.e., binary)

        # allow actual code rates slightly larger than 948/1024
        # to account for the quantization procedure in 38.214 5.1.3.1
        if self._coderate>(948/1024): # as specified in 38.212 5.4.2.1
            print(f"Warning: effective coderate r>948/1024 for n={n}, k={k}.")
        if self._coderate>(0.95): # as specified in 38.212 5.4.2.1
            raise ValueError(f"Unsupported coderate (r>0.95) for n={n}, k={k}.")
        if self._coderate<(1/5):
            # outer rep. coding currently not supported
            raise ValueError("Unsupported coderate (r<1/5).")

        # construct the basegraph according to 38.212
        self._bg = self._sel_basegraph(self._k, self._coderate)
        self._z, self._i_ls, self._k_b = self._sel_lifting(self._k, self._bg)
        self._bm = self._load_basegraph(self._i_ls, self._bg)

        # total number of codeword bits
        self._n_ldpc = self._bm.shape[1] * self._z
        # if K_real < K _target puncturing must be applied earlier
        self._k_ldpc = self._k_b * self._z

        # construct explicit graph via lifting
        pcm = self._lift_basegraph(self._bm, self._z)

        # init sub-matrices for fast encoding ("RU"-method)
        pcm_a, pcm_b_inv, pcm_c1, pcm_c2 = self._gen_submat(self._bm,
                                                            self._k_b,
                                                            self._z,
                                                            self._bg)

        self._pcm = pcm # store the sparse parity-check matrix (for decoding)

        # store indices for fast gathering (instead of explicit matmul)
        self._pcm_a_ind = self._mat_to_ind(pcm_a)
        self._pcm_b_inv_ind = self._mat_to_ind(pcm_b_inv)
        self._pcm_c1_ind = self._mat_to_ind(pcm_c1)
        self._pcm_c2_ind = self._mat_to_ind(pcm_c2)

        self._num_bits_per_symbol = num_bits_per_symbol
        # default to None so that the out_int/out_int_inv properties are
        # well-defined even if no output interleaver is used
        self._out_int = None
        self._out_int_inv = None
        if num_bits_per_symbol is not None:
            self._out_int, self._out_int_inv = self.generate_out_int(
                self._n, self._num_bits_per_symbol)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def k(self):
        """Number of input information bits."""
        return self._k

    @property
    def n(self):
        """Number of output codeword bits."""
        return self._n

    @property
    def coderate(self):
        """Coderate of the LDPC code after rate-matching."""
        return self._coderate

    @property
    def k_ldpc(self):
        """Number of LDPC information bits (incl. filler bits) before
        rate-matching."""
        return self._k_ldpc

    @property
    def n_ldpc(self):
        """Number of LDPC codeword bits before rate-matching."""
        return self._n_ldpc

    @property
    def pcm(self):
        """Parity-check matrix for given code parameters."""
        return self._pcm

    @property
    def z(self):
        """Lifting factor of the basegraph."""
        return self._z

    @property
    def num_bits_per_symbol(self):
        """Modulation order used for the rate-matching output interleaver."""
        return self._num_bits_per_symbol

    @property
    def out_int(self):
        """Output interleaver sequence as defined in 5.4.2.2 (or `None` if
        ``num_bits_per_symbol`` is `None`)."""
        return self._out_int

    @property
    def out_int_inv(self):
        """Inverse output interleaver sequence as defined in 5.4.2.2 (or
        `None` if ``num_bits_per_symbol`` is `None`)."""
        return self._out_int_inv

    #########################
    # Utility methods
    #########################

    def generate_out_int(self, n, num_bits_per_symbol):
        """Generates LDPC output interleaver sequence as defined in
        Sec 5.4.2.2 in [3GPPTS38212_LDPC]_.

        Parameters
        ----------
        n: int
            Desired output sequence length.

        num_bits_per_symbol: int
            Number of bits per QAM symbol, i.e., the modulation order.

        Output
        ------
        (perm_seq, perm_seq_inv):
            Tuple:

        perm_seq: ndarray of length n
            Containing the permuted indices.

        perm_seq_inv: ndarray of length n
            Containing the inverse permuted indices.

        Note
        ----
        The interleaver pattern depends on the modulation order and helps to
        reduce dependencies in bit-interleaved coded modulation (BICM) schemes.
        """
        # allow float inputs, but verify that they represent integer
        assert(n%1==0), "n must be int."
        assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
        n = int(n)
        assert(n>0), "n must be a positive integer."
        assert(num_bits_per_symbol>0), \
                    "num_bits_per_symbol must be a positive integer."
        num_bits_per_symbol = int(num_bits_per_symbol)

        assert(n%num_bits_per_symbol==0),\
            "n must be a multiple of num_bits_per_symbol."

        # pattern as defined in Sec 5.4.2.2:
        # perm_seq[i + j*num_bits_per_symbol] = i*(n/num_bits_per_symbol) + j
        # built as a [n/num_bits_per_symbol, num_bits_per_symbol] grid
        # (rows j, columns i) and flattened in row-major order
        num_symbols = n // num_bits_per_symbol
        perm_seq = (np.arange(num_bits_per_symbol)[np.newaxis, :] * num_symbols
                    + np.arange(num_symbols)[:, np.newaxis]).reshape(-1)

        perm_seq_inv = np.argsort(perm_seq)

        return perm_seq, perm_seq_inv

    def _sel_basegraph(self, k, r):
        """Select basegraph according to [3GPPTS38212_LDPC]_."""

        if k <= 292:
            bg = "bg2"
        elif k <= 3824 and r <= 0.67:
            bg = "bg2"
        elif r <= 0.25:
            bg = "bg2"
        else:
            bg = "bg1"

        # add for consistency
        if bg=="bg1" and k>8448:
            raise ValueError("K is not supported by BG1 (too large).")

        if bg=="bg2" and k>3840:
            raise ValueError(
                f"K is not supported by BG2 (too large) k ={k}.")

        if bg=="bg1" and r<1/3:
            raise ValueError("Only coderate>1/3 supported for BG1. "
                             "Remark: Repetition coding is currently not "
                             "supported.")

        if bg=="bg2" and r<1/5:
            raise ValueError("Only coderate>1/5 supported for BG2. "
                             "Remark: Repetition coding is currently not "
                             "supported.")

        return bg

    def _load_basegraph(self, i_ls, bg):
        """Helper to load basegraph from csv files.

        ``i_ls`` is sub_index of the basegraph and fixed during lifting
        selection.
        """

        if i_ls > 7:
            raise ValueError("i_ls too large.")

        if i_ls < 0:
            raise ValueError("i_ls cannot be negative.")

        # csv files are taken from 38.212 and dimension is explicitly given
        if bg=="bg1":
            bm = np.zeros([46, 68]) - 1 # init matrix with -1 (None positions)
        elif bg=="bg2":
            bm = np.zeros([42, 52]) - 1 # init matrix with -1 (None positions)
        else:
            raise ValueError("Basegraph not supported.")

        # and load the basegraph from csv format in folder "codes"
        # (bind the file path to a local name; the previous
        # `as codes.csv` target silently set an attribute on the
        # imported `codes` package)
        source = files(codes).joinpath(f"5G_{bg}.csv")
        with as_file(source) as csv_path:
            bg_csv = np.genfromtxt(csv_path, delimiter=";")

        # reconstruct BG for given i_ls
        r_ind = 0
        for r in np.arange(2, bg_csv.shape[0]):
            # check for next row index
            if not np.isnan(bg_csv[r, 0]):
                r_ind = int(bg_csv[r, 0])
            c_ind = int(bg_csv[r, 1]) # second column in csv is column index
            value = bg_csv[r, i_ls + 2] # i_ls entries start at offset 2
            bm[r_ind, c_ind] = value

        return bm

    def _lift_basegraph(self, bm, z):
        """Lift basegraph with lifting factor ``z`` and shifted identities as
        defined by the entries of ``bm``."""

        num_nonzero = np.sum(bm>=0) # num of non-neg elements in bm

        # init all non-zero row/column indices (integer-valued by design)
        r_idx = np.zeros(z*num_nonzero, dtype=int)
        c_idx = np.zeros(z*num_nonzero, dtype=int)
        data = np.ones(z*num_nonzero)

        # row/column indices of identity matrix for lifting
        im = np.arange(z)

        idx = 0
        for r in range(bm.shape[0]):
            for c in range(bm.shape[1]):
                # negative entries (-1) are all-zero matrix placeholders;
                # use the same criterion as num_nonzero above
                if bm[r, c] < 0:
                    continue
                # roll matrix by bm[r,c]
                c_roll = np.mod(im + int(bm[r, c]), z)
                # append rolled identity matrix to pcm
                r_idx[idx*z:(idx+1)*z] = r*z + im
                c_idx[idx*z:(idx+1)*z] = c*z + c_roll
                idx += 1

        # generate lifted sparse matrix from indices
        pcm = sp.sparse.csr_matrix((data,(r_idx, c_idx)),
                                   shape=(z*bm.shape[0], z*bm.shape[1]))
        return pcm

    def _sel_lifting(self, k, bg):
        """Select lifting as defined in Sec. 5.2.2 in [3GPPTS38212_LDPC]_.

        We assume B < K_cb, thus B'= B and C = 1, i.e., no
        additional CRC is appended. Thus, K' = B'/C = B and B is our K.

        Z is the lifting factor.
        i_ls is the set index ranging from 0...7 (specifying the exact bg
        selection).
        k_b is the number of information bit columns in the basegraph.
        """
        # lifting set according to 38.212 Tab 5.3.2-1
        s_val = [[2, 4, 8, 16, 32, 64, 128, 256],
                [3, 6, 12, 24, 48, 96, 192, 384],
                [5, 10, 20, 40, 80, 160, 320],
                [7, 14, 28, 56, 112, 224],
                [9, 18, 36, 72, 144, 288],
                [11, 22, 44, 88, 176, 352],
                [13, 26, 52, 104, 208],
                [15, 30, 60, 120, 240]]

        if bg == "bg1":
            k_b = 22
        else:
            if k > 640:
                k_b = 10
            elif k > 560:
                k_b = 9
            elif k > 192:
                k_b = 8
            else:
                k_b = 6

        # find the min of Z from Tab. 5.3.2-1 s.t. k_b*Z>=K'
        min_val = 100000
        z = 0
        i_ls = 0
        for i, s in enumerate(s_val):
            for s1 in s:
                x = k_b * s1
                if k <= x < min_val: # valid and smaller solution
                    min_val = x
                    z = s1
                    i_ls = i

        # and set K=22*Z for bg1 and K=10Z for bg2
        if bg == "bg1":
            k_b = 22
        else:
            k_b = 10

        return z, i_ls, k_b

    def _gen_submat(self, bm, k_b, z, bg):
        """Split the basegraph into multiple sub-matrices such that efficient
        encoding is possible.
        """
        g = 4 # code property (always fixed for 5G)
        mb = bm.shape[0] # number of CN rows in basegraph (BG property)

        bm_a = bm[0:g, 0:k_b]
        bm_b = bm[0:g, k_b:(k_b+g)]
        bm_c1 = bm[g:mb, 0:k_b]
        bm_c2 = bm[g:mb, k_b:(k_b+g)]

        # H could be sliced immediately (but easier to implement if based on B)
        hm_a = self._lift_basegraph(bm_a, z)

        # not required for encoding, but helpful for debugging
        #hm_b = self._lift_basegraph(bm_b, z)

        hm_c1 = self._lift_basegraph(bm_c1, z)
        hm_c2 = self._lift_basegraph(bm_c2, z)

        hm_b_inv = self._find_hm_b_inv(bm_b, z, bg)

        return hm_a, hm_b_inv, hm_c1, hm_c2

    def _find_hm_b_inv(self, bm_b, z, bg):
        """ For encoding we need to find the inverse of `hm_b` such that
        `hm_b^-1 * hm_b = I`.

        Could be done sparse
        For BG1 the structure of hm_b is given as (for all values of i_ls)
        hm_b =
        [P_A I 0 0
         P_B I I 0
         0 0 I I
         P_A 0 0 I]
        where P_B and P_A are Shifted identities.

        The inverse can be found by solving a linear system of equations
        hm_b_inv =
        [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
         I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1, I+P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1].


        For bg2 the structure of hm_b is given as (for all values of i_ls)
        hm_b =
        [P_A I 0 0
         0 I I 0
         P_B 0 I I
         P_A 0 0 I]
        where P_B and P_A are Shifted identities

        The inverse can be found by solving a linear system of equations
        hm_b_inv =
        [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
         I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         I+P_A*P_B^-1, I+P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1]

        Note: the inverse of B is simply a shifted identity matrix with
        negative shift direction.
        """

        # permutation indices
        pm_a= int(bm_b[0,0])
        if bg=="bg1":
            pm_b_inv = int(-bm_b[1, 0])
        else: # structure of B is slightly different for bg2
            pm_b_inv = int(-bm_b[2, 0])

        hm_b_inv = np.zeros([4*z, 4*z])

        im = np.eye(z)

        am = np.roll(im, pm_a, axis=1)
        b_inv = np.roll(im, pm_b_inv, axis=1)
        ab_inv = np.matmul(am, b_inv)

        # row 0
        hm_b_inv[0:z, 0:z] = b_inv
        hm_b_inv[0:z, z:2*z] = b_inv
        hm_b_inv[0:z, 2*z:3*z] = b_inv
        hm_b_inv[0:z, 3*z:4*z] = b_inv

        # row 1
        hm_b_inv[z:2*z, 0:z] = im + ab_inv
        hm_b_inv[z:2*z, z:2*z] = ab_inv
        hm_b_inv[z:2*z, 2*z:3*z] = ab_inv
        hm_b_inv[z:2*z, 3*z:4*z] = ab_inv

        # row 2
        if bg=="bg1":
            hm_b_inv[2*z:3*z, 0:z] = ab_inv
            hm_b_inv[2*z:3*z, z:2*z] = ab_inv
            hm_b_inv[2*z:3*z, 2*z:3*z] = im + ab_inv
            hm_b_inv[2*z:3*z, 3*z:4*z] = im + ab_inv
        else: # for bg2 the structure is slightly different
            hm_b_inv[2*z:3*z, 0:z] = im + ab_inv
            hm_b_inv[2*z:3*z, z:2*z] = im + ab_inv
            hm_b_inv[2*z:3*z, 2*z:3*z] = ab_inv
            hm_b_inv[2*z:3*z, 3*z:4*z] = ab_inv

        # row 3
        hm_b_inv[3*z:4*z, 0:z] = ab_inv
        hm_b_inv[3*z:4*z, z:2*z] = ab_inv
        hm_b_inv[3*z:4*z, 2*z:3*z] = ab_inv
        hm_b_inv[3*z:4*z, 3*z:4*z] = im + ab_inv

        # return results as sparse matrix
        return sp.sparse.csr_matrix(hm_b_inv)

    def _mat_to_ind(self, mat):
        """Helper to transform matrix into index representation for
        tf.gather. An index pointing to the `last_ind+1` is used for
        non-existing edges due to irregular degrees."""
        m = mat.shape[0]
        n = mat.shape[1]

        # transpose mat for sorted column format
        c_idx, r_idx, _ = sp.sparse.find(mat.transpose())

        # sort indices explicitly, as scipy.sparse.find changed from column to
        # row sorting in scipy>=1.11
        idx = np.argsort(r_idx)
        c_idx = c_idx[idx]
        r_idx = r_idx[idx]

        # find max number of no-zero entries
        n_max = np.max(mat.getnnz(axis=1))

        # init index array with n (pointer to last_ind+1, will be a default
        # value)
        gat_idx = np.zeros([m, n_max]) + n

        r_val = -1
        c_val = 0
        for idx in range(len(c_idx)):
            # check if same row or if a new row starts
            if r_idx[idx] != r_val:
                r_val = r_idx[idx]
                c_val = 0
            gat_idx[r_val, c_val] = c_idx[idx]
            c_val += 1

        gat_idx = tf.cast(tf.constant(gat_idx), tf.int32)
        return gat_idx

    def _matmul_gather(self, mat, vec):
        """Implements a fast sparse matmul via gather function."""

        # add 0 entry for gather-reduce_sum operation
        # (otherwise ragged Tensors are required)
        bs = tf.shape(vec)[0]
        vec = tf.concat([vec, tf.zeros([bs, 1], dtype=self.dtype)], 1)

        retval = tf.gather(vec, mat, batch_dims=0, axis=1)
        retval = tf.reduce_sum(retval, axis=-1)

        return retval

    def _encode_fast(self, s):
        """Main encoding function based on gathering function."""
        p_a = self._matmul_gather(self._pcm_a_ind, s)
        p_a = self._matmul_gather(self._pcm_b_inv_ind, p_a)

        # calc second part of parity bits p_b
        # second parities are given by C_1*s' + C_2*p_a' + p_b' = 0
        p_b_1 = self._matmul_gather(self._pcm_c1_ind, s)
        p_b_2 = self._matmul_gather(self._pcm_c2_ind, p_a)
        p_b = p_b_1 + p_b_2

        c = tf.concat([s, p_a, p_b], 1)

        # faster implementation of mod-2 operation c = tf.math.mod(c, 2)
        c_uint8 = tf.cast(c, tf.uint8)
        c_bin = tf.bitwise.bitwise_and(c_uint8, tf.constant(1, tf.uint8))
        c = tf.cast(c_bin, self.dtype)

        c = tf.expand_dims(c, axis=-1) # returns nx1 vector
        return c

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer."""
        # check if k and input shape match
        assert (input_shape[-1]==self._k), "Last dimension must be of length k."
        assert (len(input_shape)>=2), "Rank of input must be at least 2."

    def call(self, inputs):
        """5G LDPC encoding function including rate-matching.

        This function returns the encoded codewords as specified by the 3GPP NR Initiative [3GPPTS38212_LDPC]_ including puncturing and shortening.

        Args:
            inputs (tf.float32): Tensor of shape `[...,k]` containing the
                information bits to be encoded.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]`.

        Raises:
            ValueError: If ``inputs`` contains other values than `0` or `1`.

            InvalidArgumentError: When rank(``inputs``)<2.

            InvalidArgumentError: When shape of last dim is not ``k``.
        """

        tf.debugging.assert_type(inputs, self.dtype, "Invalid input dtype.")

        # Reshape inputs to [...,k]
        input_shape = inputs.get_shape().as_list()
        new_shape = [-1, input_shape[-1]]
        u = tf.reshape(inputs, new_shape)

        # assert if u is non binary
        if self._check_input:
            tf.debugging.assert_equal(
                tf.reduce_min(
                    tf.cast(
                        tf.logical_or(
                            tf.equal(u, tf.constant(0, self.dtype)),
                            tf.equal(u, tf.constant(1, self.dtype)),
                            ),
                        self.dtype)),
                tf.constant(1, self.dtype),
                "Input must be binary.")
            # input datatype consistency should be only evaluated once
            self._check_input = False

        batch_size = tf.shape(u)[0]

        # add "filler" bits to last positions to match info bit length k_ldpc
        u_fill = tf.concat([u,
            tf.zeros([batch_size, self._k_ldpc-self._k], self.dtype)],
                            1)

        # use optimized encoding based on tf.gather
        c = self._encode_fast(u_fill)

        c = tf.reshape(c, [batch_size, self._n_ldpc]) # remove last dim

        # remove filler bits at pos (k, k_ldpc)
        c_no_filler1 = tf.slice(c, [0, 0], [batch_size, self._k])
        c_no_filler2 = tf.slice(c,
                               [0, self._k_ldpc],
                               [batch_size, self._n_ldpc-self._k_ldpc])

        c_no_filler = tf.concat([c_no_filler1, c_no_filler2], 1)

        # shorten the first 2*Z positions and end after n bits
        # (remaining parity bits can be used for IR-HARQ)
        c_short = tf.slice(c_no_filler, [0, 2*self._z], [batch_size, self.n])
        # incremental redundancy could be generated by accessing the last bits

        # if num_bits_per_symbol is provided, apply output interleaver as
        # specified in Sec. 5.4.2.2 in 38.212
        if self._num_bits_per_symbol is not None:
            c_short = tf.gather(c_short, self._out_int, axis=-1)

        # Reshape c_short so that it matches the original input dimensions.
        # The target shape is derived from the dynamic input shape so that
        # inputs with several statically unknown (None) dimensions work;
        # the static list from get_shape() would contain None entries then.
        output_shape = tf.concat([tf.shape(inputs)[:-1], [self._n]], axis=0)
        c_reshaped = tf.reshape(c_short, output_shape)

        return tf.cast(c_reshaped, self._dtype)


###########################################################
# Deprecated aliases that will not be included in the next
# major release
###########################################################

def AllZeroEncoder(k,
                   n,
                   dtype=tf.float32,
                   **kwargs):
    """Deprecated alias of ``sionna.fec.linear.AllZeroEncoder``.

    Prints a deprecation warning and forwards ``k``, ``n``, ``dtype`` and all
    additional keyword arguments to ``sionna.fec.linear.AllZeroEncoder``,
    returning the constructed instance.
    """
    print("Warning: The alias fec.ldpc.AllZeroEncoder will not be included in "
          "Sionna 1.0. Please use sionna.fec.linear.AllZeroEncoder instead.")
    return AllZeroEncoder_new(k=k,
                              n=n,
                              dtype=dtype,
                              **kwargs)