encoding.py (29176B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for Polar encoding including 5G compliant rate-matching and CRC
concatenation."""

from sionna.fec.crc import CRCEncoder
from sionna.fec.polar.utils import generate_5g_ranking
# NOTE(review): legacy private-path import, retained only so that the module
# namespace stays unchanged; the code below uses the public ``np.issubdtype``.
from numpy.core.numerictypes import issubdtype # pylint: disable=unused-import
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer
import numbers

class PolarEncoder(Layer):
    """PolarEncoder(frozen_pos, n, dtype=tf.float32)

    Polar encoder for given code parameters.

    This layer performs polar encoding for the given ``k`` information bits and
    the `frozen set` (i.e., indices of frozen positions) specified by
    ``frozen_pos``.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        frozen_pos: ndarray
            Array of `int` defining the `n-k` frozen indices, i.e., information
            bits are mapped onto the `k` complementary positions.

        n: int
            Defining the codeword length.

        dtype: tf.DType
            Defaults to `tf.float32`. Defines the output datatype of the layer
            (internal precision is `tf.uint8`).

    Input
    -----
        inputs: [...,k], tf.float32
            2+D tensor containing the information bits to be encoded.

    Output
    ------
        : [...,n], tf.float32
            2+D tensor containing the codeword bits.

    Raises
    ------
        AssertionError
            ``k`` and ``n`` must be positive integers and ``k`` must be smaller
            (or equal) than ``n``.

        AssertionError
            If ``n`` is not a power of 2.

        AssertionError
            If the number of elements in ``frozen_pos`` is greater than ``n``.

        AssertionError
            If ``frozen_pos`` does not consist of `int`.

        ValueError
            If ``dtype`` is not supported.

        ValueError
            If ``inputs`` contains other values than `0` or `1`.

        TypeError
            If the dtype of ``inputs`` does not match the ``dtype`` of the
            layer (`tf.float32` by default).

        InvalidArgumentError
            When rank(``inputs``)<2.

        InvalidArgumentError
            When shape of last dim is not ``k``.

    Note
    ----
        As commonly done, we assume frozen bits are set to `0`. Please note
        that - although its practical relevance is only little - setting frozen
        bits to `1` may result in `affine` codes instead of linear code as the
        `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self,
                 frozen_pos,
                 n,
                 dtype=tf.float32):

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
            tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise ValueError("Unsupported dtype.")

        super().__init__(dtype=dtype)

        assert isinstance(n, numbers.Number), "n must be a number."
        n = int(n) # n can be float (e.g. as result of n=k*r)

        # Accept every integer dtype (int32, int64, uint8, ...). Comparing
        # against the Python type `int` only matches the platform default
        # integer type and would reject, e.g., int32 indices on Linux.
        assert np.issubdtype(frozen_pos.dtype, np.integer), \
            "frozen_pos must consist of ints."
        assert len(frozen_pos)<=n, \
            "Number of elements in frozen_pos cannot be greater than n."

        # Integer bit test: a positive power of 2 has exactly one bit set.
        # Unlike a log2-based check this also fails cleanly for n=0.
        assert n>0 and (n & (n-1))==0, "n must be a power of 2."

        self._k = n - len(frozen_pos)
        self._n = n
        self._frozen_pos = frozen_pos

        # generate info positions (complement of the frozen set)
        self._info_pos = np.setdiff1d(np.arange(self._n), frozen_pos)
        assert self._k==len(self._info_pos), "Internal error: invalid " \
                                             "info_pos generated."

        self._check_input = True # check input for bin. values during first call

        self._nb_stages = int(np.log2(self._n))
        self._ind_gather = self._gen_indices(self._n)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def k(self):
        """Number of information bits."""
        return self._k

    @property
    def n(self):
        """Codeword length."""
        return self._n

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding."""
        return self._info_pos

    #########################
    # Utility methods
    #########################

    def _gen_indices(self, n):
        """Pre-calculate encoding indices stage-wise for tf.gather.

        For every stage ``s`` the returned row holds, for each position that
        is updated in this stage, the index of the position that is XORed onto
        it. All other entries point to index ``n``, i.e., to the additional
        all-zero placeholder column that is appended in ``call`` such that the
        XOR leaves these positions unchanged.

        Input
        -----
        n: int
            Codeword length (power of 2).

        Output
        ------
        : [log2(n), n+1], tf.int32
            Gather indices for each encoding stage.
        """

        nb_stages = int(np.log2(n))
        # last position denotes empty placeholder (points to element n+1)
        ind_gather = np.ones([nb_stages, n+1]) * n

        # n/2 positions are updated per stage (loop invariant)
        ind_range = np.arange(n//2)

        for s in range(nb_stages):
            # destinations are all indices whose s-th bit equals 0
            ind_dest = ind_range * 2 - np.mod(ind_range, 2**s)
            # partner index has the s-th bit set
            ind_origin = ind_dest + 2**s
            ind_gather[s, ind_dest] = ind_origin # and update gather indices

        ind_gather = tf.constant(ind_gather, dtype=tf.int32)

        return ind_gather

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build and check if ``k`` and ``input_shape`` match."""
        assert (input_shape[-1]==self._k), "Invalid input shape."

    def call(self, inputs):
        """Polar encoding function.

        This function returns the polar encoded codewords for the given
        information bits ``inputs``.

        Args:
            inputs (tf.float32): Tensor of shape `[...,k]` containing the
                information bits to be encoded.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]`.

        Raises:
            ValueError: If ``inputs`` contains other values than `0` or `1`.

            TypeError: If the dtype of ``inputs`` does not match the ``dtype``
                of the layer (`tf.float32` by default).

            InvalidArgumentError: When rank(``inputs``)<2.

            InvalidArgumentError: When shape of last dim is not ``k``.
        """

        tf.debugging.assert_type(inputs, self.dtype,
                                 "Invalid input dtype.")

        # Reshape inputs to [...,k]
        tf.debugging.assert_greater(tf.rank(inputs), 1)
        input_shape = inputs.shape
        new_shape = [-1, input_shape[-1]]
        u = tf.reshape(inputs, new_shape)

        # last dim must be of length k
        tf.debugging.assert_equal(tf.shape(u)[-1],
                                  self._k,
                                  "Last dimension must be of length k.")

        # assert if binary=True and u is non binary
        if self._check_input:
            u_test = tf.cast(u, tf.float32) # only for internal check
            tf.debugging.assert_equal(tf.reduce_min(
                                        tf.cast(
                                            tf.logical_or(
                                                tf.equal(u_test, 0.),
                                                tf.equal(u_test, 1.)),
                                            tf.float32)),
                                      1.,
                                      "Input must be binary.")
            # input datatype consistency should be only evaluated once
            self._check_input = False

        # copy info bits to information set; other positions are frozen (=0)

        # return an all-zero tensor of shape [n,...]
        c = tf.zeros([self._n, tf.shape(u)[0]], self.dtype)

        # u has shape bs x k, we now want k x bs
        u_transpose = tf.transpose(u, (1,0)) # batch dim to last pos

        # index vector has at least two axis (= index_depth)
        info_pos_tf = tf.expand_dims(self.info_pos, axis=1)

        c = tf.tensor_scatter_nd_update(c, info_pos_tf, u_transpose)
        c = tf.transpose(c, (1,0))

        # append all-zero placeholder column (index n) used by the gather
        # indices of positions that are not updated in a stage
        x_nan = tf.zeros([tf.shape(c)[0], 1], self.dtype)
        x = tf.concat([c, x_nan], 1)
        x = tf.cast(x, tf.uint8)

        # loop over all stages
        for s in range(self._nb_stages):
            ind_helper = self._ind_gather[s,:]
            x_add = tf.gather(x, ind_helper, batch_dims=0, axis=1)
            #x = tf.math.logical_xor(x, x_add) # does not work well with XLA
            x = tf.bitwise.bitwise_xor(x, x_add)

        # remove last position
        c_out = x[:,0:self._n]

        # restore original shape
        input_shape_list = input_shape.as_list()
        output_shape = input_shape_list[0:-1] + [self._n]
        output_shape[0] = -1 # to support dynamic shapes
        c_reshaped = tf.reshape(c_out, output_shape)

        # cast to dtype for compatibility with other components
        return tf.cast(c_reshaped, self.dtype)

class Polar5GEncoder(PolarEncoder):
    # pylint: disable=line-too-long
    """Polar5GEncoder(k, n, channel_type="uplink", verbose=False, dtype=tf.float32)

    5G compliant Polar encoder including rate-matching following [3GPPTS38212]_
    for the uplink scenario (`UCI`) and downlink scenario (`DCI`).

    This layer performs polar encoding for ``k`` information bits and
    rate-matching such that the codeword lengths is ``n``. This includes the CRC
    concatenation and the interleaving as defined in [3GPPTS38212]_.

    Note: `block segmentation` is currently not supported (`I_seq=False`).

    We follow the basic structure from Fig. 6 in [Bioglio_Design]_.

    ..  figure:: ../figures/PolarEncoding5G.png

        Fig. 1: Implemented 5G Polar encoding chain following Fig. 6 in
        [Bioglio_Design]_ for the uplink (`I_BIL` = `True`) and the downlink
        (`I_IL` = `True`) scenario without `block segmentation`.

    For further details, we refer to [3GPPTS38212]_, [Bioglio_Design]_ and
    [Hui_ChannelCoding]_.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model. Further, the class inherits from PolarEncoder.

    Parameters
    ----------
        k: int
            Defining the number of information bit per codeword.

        n: int
            Defining the codeword length.

        channel_type: str
            Defaults to "uplink". Can be "uplink" or "downlink".

        verbose: bool
            Defaults to False. If True, rate-matching parameters will be
            printed.

        dtype: tf.DType
            Defaults to tf.float32. Defines the output datatype of the layer
            (internal precision remains tf.uint8).

    Input
    -----
        inputs: [...,k], tf.float32
            2+D tensor containing the information bits to be encoded.

    Output
    ------
        : [...,n], tf.float32
            2+D tensor containing the codeword bits.

    Raises
    ------
        AssertionError
            ``k`` and ``n`` must be positive integers and ``k`` must be smaller
            (or equal) than ``n``.

        AssertionError
            If ``n`` and ``k`` are invalid code parameters (see [3GPPTS38212]_).

        AssertionError
            If ``verbose`` is not `bool`.

        AssertionError
            If ``channel_type`` is neither "uplink" nor "downlink".

        ValueError
            If ``dtype`` is not supported.

        ValueError
            If ``k`` is smaller than 12 for the uplink.

    Note
    ----
        The encoder supports the `uplink` Polar coding (`UCI`) scheme from
        [3GPPTS38212]_ and the `downlink` Polar coding (`DCI`) [3GPPTS38212]_,
        respectively.

        For `12 <= k <= 19` the 3 additional parity bits as defined in
        [3GPPTS38212]_ are not implemented as it would also require a
        modified decoding procedure to materialize the potential gains.

        `Code segmentation` is currently not supported and, thus, ``n`` is
        limited to a maximum length of 1088 codeword bits.

        For the downlink scenario, the input length is limited to `k <= 140`
        information bits due to the limited input bit interleaver size
        [3GPPTS38212]_.

        For simplicity, the implementation does not exactly re-implement the
        `DCI` scheme from [3GPPTS38212]_. This implementation neglects the
        `all-one` initialization of the CRC shift register and the scrambling
        of the CRC parity bits with the `RNTI`.
    """

    def __init__(self,
                 k,
                 n,
                 channel_type="uplink",
                 verbose=False,
                 dtype=tf.float32):

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
            tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise ValueError("Unsupported dtype.")

        assert isinstance(k, numbers.Number), "k must be a number."
        assert isinstance(n, numbers.Number), "n must be a number."
        k = int(k) # k or n can be float (e.g. as result of n=k*r)
        n = int(n) # k or n can be float (e.g. as result of n=k*r)
        assert n>=k, "Invalid coderate (>1)."
        assert isinstance(verbose, bool), "verbose must be bool."

        assert channel_type in ("uplink","downlink"), \
            "Unsupported channel_type."
        self._channel_type = channel_type

        self._k_target = k
        self._n_target = n
        self._verbose = verbose

        # Initialize rate-matcher
        crc_degree, n_polar, frozen_pos, idx_rm, idx_input = \
            self._init_rate_match(k, n)

        self._frozen_pos = frozen_pos # Required for decoder
        self._ind_rate_matching = idx_rm # Index for gather-based rate-matching
        self._ind_input_int = idx_input # Index for input interleaver

        # Initialize CRC encoder
        self._enc_crc = CRCEncoder(crc_degree, dtype=dtype)

        # Init super-class (PolarEncoder)
        super().__init__(frozen_pos, n_polar, dtype=dtype)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def enc_crc(self):
        """CRC encoder layer used for CRC concatenation."""
        return self._enc_crc

    @property
    def k_target(self):
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n_target(self):
        """Codeword length including rate-matching."""
        return self._n_target

    @property
    def k_polar(self):
        """Number of information bits of the underlying Polar code."""
        return self._k

    @property
    def n_polar(self):
        """Codeword length of the underlying Polar code."""
        return self._n

    @property
    def k(self):
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n(self):
        """Codeword length including rate-matching."""
        return self._n_target

    def subblock_interleaving(self, u):
        """Sub-block interleaving as defined in Sec 5.4.1.1 [3GPPTS38212]_.

        Input
        -----
            u: ndarray
                1D array to be interleaved. Length of ``u`` must be a multiple
                of 32.

        Output
        ------
            : ndarray
                Interleaved version of ``u`` with same shape and dtype as ``u``.

        Raises
        ------
            AssertionError
                If length of ``u`` is not a multiple of 32.

        """

        k = u.shape[-1]
        assert np.mod(k,32)==0, \
            "length for sub-block interleaving must be a multiple of 32."
        y = np.zeros_like(u)

        # Permutation according to Tab 5.4.1.1.1-1 in 38.212
        perm = np.array([0, 1, 2, 4, 3, 5, 6, 7, 8, 16, 9, 17, 10, 18, 11, 19,
                         12, 20, 13, 21, 14, 22, 15, 23, 24, 25, 26, 28, 27,
                         29, 30, 31])

        for n in range(k):
            # index of the sub-block (out of 32) that contains position n
            i = int(np.floor(32*n/k))
            # permute the sub-block, keep the offset within the sub-block
            j = perm[i] * k/32 + np.mod(n, k/32)
            j = int(j)
            y[n] = u[j]

        return y

    def channel_interleaver(self, c):
        """Triangular interleaver following Sec. 5.4.1.3 in [3GPPTS38212]_.

        Input
        -----
            c: ndarray
                1D array to be interleaved.

        Output
        ------
            : ndarray
                Interleaved version of ``c`` with same shape and dtype as ``c``.

        """

        n = c.shape[-1] # Denoted as E in 38.212
        c_int = np.zeros_like(c)

        # Find smallest T s.t. T*(T+1)/2 >= n
        t = 0
        while t*(t+1)/2 < n:
            t +=1

        # Write row-wise into the upper-left triangle of a t x t matrix
        v = np.zeros([t, t])
        ind_k = 0
        for ind_i in range(t):
            for ind_j in range(t-ind_i):
                if ind_k < n:
                    v[ind_i, ind_j] = c[ind_k]
                else:
                    v[ind_i, ind_j] = np.nan # NULL
                # Store nothing otherwise
                ind_k += 1

        # Read column-wise and skip all NULL entries
        ind_k = 0
        for ind_j in range(t):
            for ind_i in range(t-ind_j):
                if not np.isnan(v[ind_i, ind_j]):
                    c_int[ind_k] = v[ind_i, ind_j]
                    ind_k += 1
        return c_int

    def input_interleaver(self, c):
        """Input interleaver following Sec. 5.3.1.1 (Table 5.3.1.1-1) in
        [3GPPTS38212]_.

        Input
        -----
            c: ndarray
                1D array to be interleaved.

        Output
        ------
            : ndarray
                Interleaved version of ``c`` with same shape as ``c`` and
                dtype `int`.

        Raises
        ------
            AssertionError
                If length of ``c`` is larger than 164.

        """
        # 38.212 Table 5.3.1.1-1
        p_il_max_table = [0, 2, 4, 7, 9, 14, 19, 20, 24, 25, 26, 28, 31, 34,
            42, 45, 49, 50, 51, 53, 54, 56, 58, 59, 61, 62, 65, 66, 67, 69,
            70, 71, 72, 76, 77, 81, 82, 83, 87, 88, 89, 91, 93, 95, 98, 101,
            104, 106, 108, 110, 111, 113, 115, 118, 119, 120, 122, 123, 126,
            127, 129, 132, 134, 138, 139, 140, 1, 3, 5, 8, 10, 15, 21, 27, 29,
            32, 35, 43, 46, 52, 55, 57, 60, 63, 68, 73, 78, 84, 90, 92, 94, 96,
            99, 102, 105, 107, 109, 112, 114, 116, 121, 124, 128, 130, 133,
            135, 141, 6, 11, 16, 22, 30, 33, 36, 44, 47, 64, 74, 79, 85, 97,
            100, 103, 117, 125, 131, 136, 142, 12, 17, 23, 37, 48, 75, 80, 86,
            137, 143, 13, 18, 38, 144, 39, 145, 40, 146, 41, 147, 148, 149,
            150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
            163]
        k_il_max = 164
        k = len(c)
        assert k<=k_il_max, "Input interleaver only defined for length of 164."
        c_apo = np.empty(k, 'int')
        i = 0
        # Only the last k entries of the length-164 pattern are relevant;
        # smaller table entries are skipped and the rest is shifted to 0..k-1
        for p_il_max in p_il_max_table:
            if p_il_max >= (k_il_max - k):
                c_apo[i] = c[p_il_max - (k_il_max - k)]
                i += 1
        return c_apo

    #########################
    # Utility methods
    #########################

    def _init_rate_match(self, k_target, n_target):
        """Implementing polar rate matching according to [3GPPTS38212]_.

        Please note that this part of the code only runs during the
        initialization and, thus, is not performance critical. For easier
        alignment and traceability with the standard document [3GPPTS38212]_
        the implementation prefers `for loop`-based indexing.

        The relation of terminology between [3GPPTS38212]_ and this code is
        given as:
        `A`...`k_target`
        `E`...`n_target`
        `K`...`k_polar`
        `N`...`n_polar`
        `L`...`k_crc`.

        Input
        -----
            k_target: int
                Number of information bits per codeword (without CRC).

            n_target: int
                Codeword length after rate-matching.

        Output
        ------
            crc_pol: str
                Name of the CRC polynomial ("CRC6", "CRC11" or "CRC24C").

            n_polar: int
                Length of the Polar mother code.

            frozen_pos: ndarray
                Frozen positions of the Polar mother code.

            idx_rate_matched: ndarray
                Indices (of length ``n_target``) combining sub-block
                interleaving, rate-matching and, for the uplink, channel
                interleaving in a single gather operation.

            ind_input_int: ndarray or `None`
                Indices of the input bit interleaver for the downlink.
                `None` for the uplink.

        Raises
        ------
            AssertionError
                If ``k_target`` and ``n_target`` are invalid code parameters.

            ValueError
                If ``k_target`` is smaller than 12 for the uplink.
        """

        # Check input for consistency (see Sec. 6.3.1.2.1 for UL)

        # currently not relevant (segmentation not supported)
        # assert k_target<=1706, "Maximum supported codeword length for" \
        #                        "Polar coding is 1706."

        assert n_target >= k_target, "n must be larger or equal k."
        assert n_target >= 18, \
            "n<18 is not supported by the 5G Polar coding scheme."
        assert k_target <= 1013, \
            "k too large - no codeword segmentation supported at the moment."
        assert n_target <= 1088, \
            "n too large - no codeword segmentation supported at the moment."

        # Select CRC polynomials (see Sec. 6.3.1.2.1 for UL)
        if self._channel_type=="uplink":

            if 12<=k_target<=19:
                crc_pol = "CRC6"
                k_crc = 6
            elif k_target >=20:
                crc_pol = "CRC11"
                k_crc = 11
            else:
                raise ValueError("k_target<12 is not supported in 5G NR for " \
                    "the uplink; please use 'channel coding of small block " \
                    "lengths' scheme from Sec. 5.3.3 in 3GPP 38.212 instead.")

            # PC bit for k_target = 12-19 bits (see Sec. 6.3.1.3.1 for UL)
            n_pc = 0
            #n_pc_wm = 0
            if k_target<=19:
                #n_pc = 3
                n_pc = 0 # Currently deactivated
                print("Warning: For 12<=k<=19 additional 3 parity-check bits " \
                      "are defined in 38.212. They are currently not " \
                      "implemented by this encoder and, thus, ignored.")
                # For n_target-k_target>175, n_pc_wm = 1 would apply
                # (not implemented)

        else: # downlink channel
            # for downlink CRC24 is used
            # remark: in PDCCH messages are limited to k=140
            # as the input interleaver does not support longer sequences
            assert k_target <= 140, \
                "k too large for downlink channel configuration."
            assert n_target >= 25, \
                "n too small for downlink channel configuration with 24 bit " \
                "CRC."
            assert n_target <= 576, \
                "n too large for downlink channel configuration."
            crc_pol = "CRC24C" # following 7.3.2
            k_crc = 24
            n_pc = 0

        # No input interleaving for uplink needed

        # Calculate Polar payload length (CRC bits are treated as info bits)
        k_polar = k_target + k_crc + n_pc

        assert k_polar <= n_target, "Device is not expected to be configured " \
            "with k_polar + k_crc + n_pc > n_target."

        # Select polar mother code length n_polar
        n_min = 5
        n_max = 10 # For uplink; otherwise 9

        # Select rate-matching scheme following Sec. 5.3.1
        if (n_target <= ((9/8) * 2**(np.ceil(np.log2(n_target))-1)) and
            k_polar/n_target < 9/16):
            n1 = np.ceil(np.log2(n_target))-1
        else:
            n1 = np.ceil(np.log2(n_target))
        n2 = np.ceil(np.log2(8*k_polar)) #Lower bound such that rate > 1/8
        n_polar = int(2**np.max((np.min([n1, n2, n_max]), n_min)))

        # Puncturing and shortening as defined in Sec. 5.4.1.1
        prefrozen_pos = [] # List containing the pre-frozen indices
        if n_target < n_polar:
            if k_polar/n_target <= 7/16:
                # Puncturing
                if self._verbose:
                    print("Using puncturing for rate-matching.")
                n_int = 32 * np.ceil((n_polar-n_target) / 32)
                int_pattern = self.subblock_interleaving(np.arange(n_int))
                for i in range(n_polar-n_target):
                    # Freeze additional bits
                    prefrozen_pos.append(int(int_pattern[i]))
                if n_target >= 3*n_polar/4:
                    t = int(np.ceil(3/4*n_polar - n_target/2) - 1)
                else:
                    t = int(np.ceil(9/16*n_polar - n_target/4) - 1)
                # Extra freezing
                for i in range(t):
                    prefrozen_pos.append(i)
            else:
                # Shortening ("through" sub-block interleaver)
                if self._verbose:
                    print("Using shortening for rate-matching.")
                n_int = 32 * np.ceil((n_polar) / 32)
                int_pattern = self.subblock_interleaving(np.arange(n_int))
                for i in range(n_target, n_polar):
                    # int_pattern is a float array (n_int is float); store
                    # integer indices as done in the puncturing branch
                    prefrozen_pos.append(int(int_pattern[i]))

        # Remove duplicates
        prefrozen_pos = np.unique(prefrozen_pos)

        # Find the remaining n_polar - k_polar - |frozen_set|

        # Load full channel ranking
        ch_ranking, _ = generate_5g_ranking(0, n_polar, sort=False)

        # Remove positions that are already frozen by `pre-freezing` stage
        info_cand = np.setdiff1d(ch_ranking, prefrozen_pos, assume_unique=True)

        # Identify k_polar most reliable positions from candidate positions
        # (taken from the end of the candidate list)
        info_pos = []
        for i in range(k_polar):
            info_pos.append(info_cand[-i-1])

        # Sort and create frozen positions for n_polar indices (no shortening)
        info_pos = np.sort(info_pos).astype(int)
        frozen_pos = np.setdiff1d(np.arange(n_polar),
                                  info_pos,
                                  assume_unique=True)

        # For downlink only: generate input bit interleaver
        if self._channel_type=="downlink":
            if self._verbose:
                print("Using input bit interleaver for downlink.")
            ind_input_int = self.input_interleaver(np.arange(k_polar))
        else:
            ind_input_int = None

        # Generate tf.gather indices for sub-block interleaver
        ind_sub_int = self.subblock_interleaving(np.arange(n_polar))

        # Rate matching via circular buffer as defined in Sec. 5.4.1.2
        c_int = np.arange(n_polar)
        idx_c_matched = np.zeros([n_target])
        if n_target >= n_polar:
            # Repetition coding
            if self._verbose:
                print("Using repetition coding for rate-matching")
            for ind in range(n_target):
                idx_c_matched[ind] = c_int[np.mod(ind, n_polar)]
        else:
            if k_polar/n_target <= 7/16:
                # Puncturing
                for ind in range(n_target):
                    idx_c_matched[ind] = c_int[ind+n_polar-n_target]
            else:
                # Shortening
                for ind in range(n_target):
                    idx_c_matched[ind] = c_int[ind]

        # For uplink only: generate channel interleaver
        if self._channel_type=="uplink":
            if self._verbose:
                print("Using channel interleaver for uplink.")
            ind_channel_int = self.channel_interleaver(np.arange(n_target))

            # Combine indices for single tf.gather operation
            ind_t = idx_c_matched[ind_channel_int].astype(int)
            idx_rate_matched = ind_sub_int[ind_t]
        else: # no channel interleaver for downlink
            idx_rate_matched = ind_sub_int[idx_c_matched.astype(int)]

        if self._verbose:
            print("Code parameters after rate-matching: " \
                  f"k = {k_target}, n = {n_target}")
            print(f"Polar mother code: k_polar = {k_polar}, " \
                  f"n_polar = {n_polar}")
            print("Using", crc_pol)
            print("Frozen positions: ", frozen_pos)
            print("Channel type: " + self._channel_type)

        return crc_pol, n_polar, frozen_pos, idx_rate_matched, ind_input_int

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build and check if ``k`` and ``input_shape`` match."""
        assert (input_shape[-1]==self._k_target), "Invalid input shape."

    def call(self, inputs):
        """Polar encoding function including rate-matching and CRC encoding.

        This function returns the polar encoded codewords for the given
        information bits ``inputs`` following [3GPPTS38212]_ including
        rate-matching.

        Args:
            inputs (tf.float32): Tensor of shape `[...,k]` containing the
                information bits to be encoded.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]`.

        Raises:
            TypeError: If ``inputs`` is not `tf.float32`.

            InvalidArgumentError: When rank(``inputs``)<2.

            InvalidArgumentError: When shape of last dim is not ``k``.
        """

        # Reshape inputs to [...,k]
        tf.debugging.assert_greater(tf.rank(inputs), 1)
        input_shape = inputs.shape
        new_shape = [-1, input_shape[-1]]
        u = tf.reshape(inputs, new_shape)

        # Consistency check (i.e., binary) of inputs will be done in super_class

        # CRC encode
        u_crc = self._enc_crc(u)

        # For downlink only: apply input bit interleaver
        if self._channel_type=="downlink":
            u_crc = tf.gather(u_crc, self._ind_input_int, axis=-1)

        # Encode bits (= channel allocation + Polar transform)
        c = super().call(u_crc)

        # Sub-block interleaving with 32 sub-blocks as in Sec. 5.4.1.1
        # Rate matching via circular buffer as defined in Sec. 5.4.1.2
        # For uplink only: channel interleaving (i_bil=True)
        c_matched = tf.gather(c, self._ind_rate_matching, axis=1)

        # Restore original shape
        input_shape_list = input_shape.as_list()
        output_shape = input_shape_list[0:-1] + [self._n_target]
        output_shape[0] = -1 # To support dynamic shapes
        c_reshaped = tf.reshape(c_matched, output_shape)

        return c_reshaped