decoding.py
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for convolutional code decoding (Viterbi and BCJR)."""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.utils import int2bin
from sionna.fec.conv.utils import polynomial_selector, Trellis


class ViterbiDecoder(Layer):
    # pylint: disable=line-too-long
    r"""ViterbiDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, method='soft_llr', return_info_bits=True, output_dtype=tf.float32, **kwargs)

    Implements the Viterbi decoding algorithm [Viterbi]_ that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input either LLR values (``method`` = `soft_llr`) or hard bit
    values (``method`` = `hard`) and returns a hard-decided estimate of the
    information bits.

    The class inherits from the Keras layer class and can be used as layer
    in a Keras model.

    Parameters
    ----------
    encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
        If ``encoder`` is provided as input, the following input parameters
        are not required and will be ignored: ``gen_poly``, ``rate``,
        ``constraint_length``, ``rsc``, ``terminate``. They will be inferred
        from the ``encoder`` object itself. If ``encoder`` is `None`, the
        above parameters must be provided explicitly.

    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If `None`,
        ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/2 and 1/3. Only required if ``gen_poly`` is
        `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the encoder is recursive-systematic
        for the given generator polynomials.
        `True` indicates the encoder is recursive-systematic.
        `False` indicates the encoder is feed-forward non-systematic.

    terminate: boolean
        Boolean flag indicating whether the codeword is terminated.
        `True` indicates the codeword is terminated to the all-zero state.
        `False` indicates the codeword is not terminated.

    method: str
        Valid values are `soft_llr` or `hard`. Defines how the path metrics
        are computed:
        `soft_llr` expects channel LLRs as inputs, whereas
        `hard` assumes a binary symmetric channel (BSC) with 0/1 values as
        inputs. In case of `hard`, ``inputs`` will be quantized to 0/1
        values.

    return_info_bits: boolean
        Defaults to `True`. If `True`, estimates of the information bits
        are returned; if `False`, estimates of the codeword are returned.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer.

    Input
    -----
    inputs: [...,n], tf.float32
        2+D tensor containing the (noisy) channel output symbols, where `n`
        denotes the codeword length

    Output
    ------
    : [...,rate*n], tf.float32
        2+D tensor containing the estimates of the information bit tensor
        (or the codeword estimate if ``return_info_bits`` is `False`)

    Note
    ----
    A full implementation of the decoder rather than a windowed approach
    is used. For a given codeword of duration `T`, the path metric is
    computed from time `0` to `T`, and the path with the optimal metric at
    time `T` is selected. The optimal path is then traced back from `T` to
    `0` to output the estimate of the information bit vector used to
    encode. For larger codewords, note that the current method is
    suboptimal in terms of memory utilization and latency.
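
    Example
    -------
    A minimal usage sketch, assuming ``llr`` holds channel LLRs of shape
    `[batch_size, n]` (the tensor name is chosen here for illustration):

    >>> decoder = ViterbiDecoder(rate=1/2,
    ...                          constraint_length=3,
    ...                          method='soft_llr')
    >>> # u_hat has shape [batch_size, k]
    >>> u_hat = decoder(llr)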
88 """ 89 90 def __init__(self, 91 encoder=None, 92 gen_poly=None, 93 rate=1/2, 94 constraint_length=3, 95 rsc=False, 96 terminate=False, 97 method='soft_llr', 98 return_info_bits=True, 99 output_dtype=tf.float32, 100 **kwargs): 101 102 super().__init__(**kwargs) 103 if encoder is not None: 104 self._gen_poly = encoder.gen_poly 105 self._trellis = encoder.trellis 106 self._terminate = encoder.terminate 107 else: 108 valid_rates = (1/2, 1/3) 109 valid_constraint_length = (3, 4, 5, 6, 7, 8) 110 111 if gen_poly is not None: 112 assert all(isinstance(poly, str) for poly in gen_poly), \ 113 "Each polynomial must be a string." 114 assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \ 115 "Each polynomial must be of same length." 116 assert all(all( 117 char in ['0','1'] for char in poly) for poly in gen_poly),\ 118 "Each polynomial must be a string of 0's and 1's." 119 self._gen_poly = gen_poly 120 else: 121 valid_rates = (1/2, 1/3) 122 valid_constraint_length = (3, 4, 5, 6, 7, 8) 123 124 assert constraint_length in valid_constraint_length, \ 125 "Constraint length must be between 3 and 8." 126 assert rate in valid_rates, \ 127 "Rate must be 1/3 or 1/2." 128 self._gen_poly = polynomial_selector(rate, constraint_length) 129 130 # init Trellis parameters 131 self._trellis = Trellis(self.gen_poly, rsc=rsc) 132 self._terminate = terminate 133 134 self._coderate_desired = 1/len(self.gen_poly) 135 self._mu = len(self._gen_poly[0])-1 136 assert method in ('soft_llr', 'hard'), \ 137 "method must be `soft_llr` or `hard`." 138 139 # conv_k denotes number of input bit streams 140 # can only be 1 in current implementation 141 self._conv_k = self._trellis.conv_k 142 143 # conv_n denotes number of output bits for conv_k input bits 144 self._conv_n = self._trellis.conv_n 145 146 self._k = None 147 self._n = None 148 # num_syms denote number of encoding periods or state transitions. 149 self._num_syms = None 150 151 self._ni = 2**self._conv_k 152 self._no = 2**self._conv_n 153 self._ns = self._trellis.ns 154 155 self._method = method 156 self._return_info_bits = return_info_bits 157 self.output_dtype = output_dtype 158 # If i->j state transition emits symbol k, tf.gather with ipst_op_idx 159 # gathers (i,k) element from input in row j. 160 self.ipst_op_idx = self._mask_by_tonode() 161 162 ######################################### 163 # Public methods and properties 164 ######################################### 165 166 @property 167 def gen_poly(self): 168 """Generator polynomial used by the encoder""" 169 return self._gen_poly 170 171 @property 172 def coderate(self): 173 """Rate of the code used in the encoder""" 174 if self.terminate and self._n is None: 175 print("Note that, due to termination, the true coderate is lower "\ 176 "than the returned design rate. 
"\ 177 "The exact true rate is dependent on the value of n and "\ 178 "hence cannot be computed before the first call().") 179 self._coderate = self._coderate_desired 180 elif self.terminate and self._n is not None: 181 k = self._coderate_desired*self._n - self._mu 182 self._coderate = k/self._n 183 return self._coderate 184 185 @property 186 def trellis(self): 187 """Trellis object used during encoding""" 188 return self._trellis 189 190 @property 191 def terminate(self): 192 """Indicates if the encoder is terminated during codeword generation""" 193 return self._terminate 194 195 @property 196 def k(self): 197 """Number of information bits per codeword""" 198 if self._k is None: 199 print("Note: The value of k cannot be computed before the first " \ 200 "call().") 201 return self._k 202 203 @property 204 def n(self): 205 """Number of codeword bits""" 206 if self._n is None: 207 print("Note: The value of n cannot be computed before the first " \ 208 "call().") 209 return self._n 210 211 ######################### 212 # Utility functions 213 ######################### 214 215 def _mask_by_tonode(self): 216 r""" 217 _Ns x _No index matrix, each element of shape (2,) 218 where num_ops = 2**conv_n 219 When applied as tf.gather index on a Ns x num_ops matrix 220 ((i,j) denoting metric for prev_st=i and output=j) 221 the output is matrix sorted by next_state. Row i in output 222 denotes the 2 possible metrics for transition to state i. 223 """ 224 cnst = self._ns * self._ni 225 from_nodes_vec = tf.reshape(self._trellis.from_nodes,(cnst,)) 226 op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,)) 227 st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx])) 228 st_op_idx = tf.reshape(st_op_idx[None,:,:],(self._ns, self._ni, 2)) 229 230 return st_op_idx 231 232 def _update_fwd(self, init_cm, bm_mat): 233 state_vec = tf.tile(tf.range(self._ns, dtype=tf.int32)[None,:], 234 [tf.shape(init_cm)[0], 1]) 235 ipst_op_mask = tf.tile(self.ipst_op_idx[None,:], [tf.shape(init_cm)[0], 1, 1, 1]) 236 237 cm_ta = tf.TensorArray(tf.float32, size=self._num_syms, 238 dynamic_size=False, clear_after_read=False) 239 tb_ta = tf.TensorArray(tf.int32, size=self._num_syms, 240 dynamic_size=False, clear_after_read=False) 241 242 prev_cm = init_cm 243 for idx in tf.range(0, self._n, self._conv_n): 244 sym = idx//self._conv_n 245 metrics_t = bm_mat[..., sym] 246 # Ns x No matrix- (s,j) is path_metric at state s with transition op=j 247 sum_metric = prev_cm[:,:,None] + metrics_t[:,None,:] 248 sum_metric_bytonode = tf.gather_nd(sum_metric, ipst_op_mask, 249 batch_dims=1) 250 251 tb_state_idx = tf.math.argmin(sum_metric_bytonode, axis=2) 252 tb_state_idx = tf.cast(tb_state_idx, tf.int32) 253 254 # Transition to states argmin state index 255 from_st_idx = tf.transpose(tf.stack([state_vec, tb_state_idx]), 256 perm=[1, 2, 0]) 257 258 tb_states = tf.gather_nd(self._trellis.from_nodes, from_st_idx) 259 cum_t = tf.math.reduce_min(sum_metric_bytonode,axis=2) 260 261 cm_ta = cm_ta.write(sym, cum_t) 262 tb_ta = tb_ta.write(sym, tb_states) 263 264 prev_cm = cum_t 265 266 return cm_ta, tb_ta 267 268 269 def _op_bits_path(self, paths): 270 r""" 271 Given a path, compute the input bit stream that results in the path. 272 Used in call() where the input is optimal path (seq of states) such 273 as the path returned by _return_optimal. 
274 """ 275 paths = tf.cast(paths, tf.int32) 276 ip_bits = tf.TensorArray(tf.int32, 277 size=paths.shape[-1]-1, 278 dynamic_size=False, 279 clear_after_read=False) 280 dec_syms = tf.TensorArray(tf.int32, 281 size=paths.shape[-1]-1, 282 dynamic_size=False, 283 clear_after_read=False) 284 ni = self._trellis.ni 285 ip_sym_mask = tf.range(ni)[None, :] 286 287 for sym in tf.range(1, paths.shape[-1]): 288 289 # gather index from paths to enable XLA 290 # replaces p_idx = paths[:,sym-1:sym+1] 291 p_idx = tf.gather(paths, [sym-1, sym], axis=-1) 292 dec_ = tf.gather_nd(self._trellis.op_mat, p_idx) 293 294 dec_syms = dec_syms.write(sym-1, value=dec_) 295 # bs x ni boolean tensor. Each row has a True and False. True 296 # corresponds to input_bit which produced the next state (t=sym) 297 match_st = tf.math.equal( 298 tf.gather(self._trellis.to_nodes,paths[:, sym-1]), 299 tf.tile(paths[:, sym][:, None], [1, 2]) 300 ) 301 302 # tf.boolean_mask throws error in XLA mode 303 #ip_bit = tf.boolean_mask(ip_sym_mask, match_st) 304 305 ip_bit_ = tf.where(match_st, 306 ip_sym_mask, 307 tf.zeros_like(ip_sym_mask)) 308 ip_bit = tf.reduce_sum(ip_bit_, axis=-1) 309 ip_bits = ip_bits.write(sym-1, ip_bit) 310 311 ip_bit_vec_est = tf.transpose(ip_bits.stack()) 312 ip_sym_vec_est = tf.transpose(dec_syms.stack()) 313 314 return ip_bit_vec_est, ip_sym_vec_est 315 316 def _optimal_path(self, cm_, tb_): 317 r""" 318 Compute optimal path (state at each time t) given tensors cm_ & tb_ 319 of shapes (None, Ns, T). Output is of shape (None, T) 320 cm_: cumulative metrics for each state at time t(0 to T) 321 tb_: traceback state for each state at time t(0 to T) 322 """ 323 # tb and ca are of shape (batch x self._ns x num_syms) 324 assert(tb_.get_shape()[1] == self._ns), "Invalid shape." 325 optst_ta = tf.TensorArray(tf.int32, size=tb_.shape[-1], 326 dynamic_size=False, 327 clear_after_read=False) 328 if self._terminate: 329 opt_term_state = tf.zeros((tf.shape(cm_)[0],), tf.int32) 330 else: 331 opt_term_state =tf.cast(tf.argmin(cm_[:, :, -1], axis=1), tf.int32) 332 optst_ta = optst_ta.write(tb_.shape[-1]-1,opt_term_state) 333 334 for sym in tf.range(tb_.shape[-1]-1, 0, -1): 335 opt_st = optst_ta.read(sym)[:,None] 336 337 idx_ = tf.concat([tf.range(tf.shape(cm_)[0])[:,None], opt_st], 338 axis=1) 339 opt_st_tminus1 = tf.gather_nd(tb_[:, :, sym], idx_) 340 341 optst_ta = optst_ta.write(sym-1, opt_st_tminus1) 342 343 return tf.transpose(optst_ta.stack()) 344 345 def _bmcalc(self, y): 346 """ 347 Calculate branch metrics for a given noisy codeword tensor. 348 For each time period t, _bmcalc computes the distance of symbol 349 vector y[t] from each possible output symbol. 350 The distance metric is L2 distance if decoder parameter `method` is 351 "soft". 352 353 The distance metric is L1 distance if parameter `method` is "hard". 354 """ 355 356 op_bits = np.stack( 357 [int2bin(op, self._conv_n) for op in range(self._no)]) 358 op_mat = tf.cast(tf.tile(op_bits, [1,self._num_syms]), tf.float32) 359 op_mat = tf.expand_dims(op_mat, axis=0) 360 y = tf.expand_dims(y, axis=1) 361 if self._method=='soft_llr': 362 op_mat_sign = 1 - 2.*op_mat 363 llr_sign = -1. 
    def _bmcalc(self, y):
        """
        Calculate branch metrics for a given noisy codeword tensor.
        For each time period t, _bmcalc computes the distance of the symbol
        vector y[t] from each possible output symbol.
        If the decoder parameter `method` is "soft_llr", the metric is the
        correlation of the channel LLRs with the bit signs of the candidate
        output symbol.

        If `method` is "hard", the metric is the L1 (Hamming) distance.
        """

        op_bits = np.stack(
            [int2bin(op, self._conv_n) for op in range(self._no)])
        op_mat = tf.cast(tf.tile(op_bits, [1, self._num_syms]), tf.float32)
        op_mat = tf.expand_dims(op_mat, axis=0)
        y = tf.expand_dims(y, axis=1)
        if self._method=='soft_llr':
            op_mat_sign = 1 - 2.*op_mat
            llr_sign = -1. * tf.math.multiply(y, op_mat_sign)
            llr_sign = tf.reshape(llr_sign,
                        (-1, self._no, self._num_syms, self._conv_n))
            # Sum of LLR*(sign of bit) for each symbol
            bm = tf.math.reduce_sum(llr_sign, axis=-1)

        else: # method == 'hard'
            diffabs = tf.math.abs(y-op_mat)
            diffabs = tf.reshape(diffabs,
                        (-1, self._no, self._num_syms, self._conv_n))
            # Manhattan distance of symbols
            bm = tf.math.reduce_sum(diffabs, axis=-1)

        return bm

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions."""
        # assert rank of at least two
        tf.debugging.assert_greater_equal(len(input_shape), 2)

        self._n = input_shape[-1]

        divisible = tf.math.floormod(self._n, self._conv_n)
        assert divisible==0, "Length of the codeword must be divisible " \
            "by the number of output bits per symbol."

        self._num_syms = int(self._n*self._coderate_desired)

        self._num_term_syms = self._mu if self.terminate else 0
        self._k = self._num_syms - self._num_term_syms

    def call(self, inputs):
        """
        Viterbi decoding function.

        inputs is the (noisy) codeword tensor where the last dimension
        should equal n. All leading dimensions are treated as batch
        dimensions.
        """
        LARGEDIST = 2.**20 # pylint: disable=invalid-name

        tf.debugging.assert_type(inputs, tf.float32,
                                 message="input must be tf.float32.")
        if self._method == 'hard':
            inputs = tf.math.floormod(tf.cast(inputs, tf.int32), 2)
        elif self._method == 'soft_llr':
            inputs = -1. * inputs

        inputs = tf.cast(inputs, tf.float32)

        output_shape = inputs.get_shape().as_list()
        y_resh = tf.reshape(inputs, [-1, self._n])
        output_shape[0] = -1
        if self._return_info_bits:
            output_shape[-1] = self._k # assign k to the last dimension
        else:
            output_shape[-1] = self._n
        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(y_resh)

        # All paths start in the all-zero state: its metric is zero and
        # all other state metrics are initialized to a large value
        init_cm_np = np.full((self._ns,), LARGEDIST)
        init_cm_np[0] = 0.0
        prev_cm_ = tf.convert_to_tensor(init_cm_np, dtype=tf.float32)
        prev_cm = tf.tile(prev_cm_[None,:], [tf.shape(y_resh)[0], 1])

        cm_ta, tb_ta = self._update_fwd(prev_cm, bm_mat)

        cm = tf.transpose(cm_ta.stack(), perm=[1,2,0])
        tb = tf.transpose(tb_ta.stack(), perm=[1,2,0])
        del cm_ta, tb_ta

        zero_st = tf.zeros((tf.shape(y_resh)[0], 1), tf.int32)
        opt_path = self._optimal_path(cm, tb)
        opt_path = tf.concat((zero_st, opt_path), axis=1)
        del cm, tb
        msghat, cwhat = self._op_bits_path(opt_path)
        if self._return_info_bits:
            msghat = msghat[...,:self._k]
            output = tf.cast(msghat, self.output_dtype)
        else:
            output = tf.cast(cwhat, self.output_dtype)
        output_reshaped = tf.reshape(output, output_shape)

        return output_reshaped

class BCJRDecoder(Layer):
    # pylint: disable=line-too-long
    r"""BCJRDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, hard_out=True, algorithm='map', output_dtype=tf.float32, **kwargs)

    Implements the BCJR decoding algorithm [BCJR]_ that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input either channel LLRs or a tuple
    (channel LLRs, a priori LLRs). Returns an estimate of the information
    bits, either as output LLRs ( ``hard_out`` = `False`) or as
    hard-decided bits ( ``hard_out`` = `True`).

    The class inherits from the Keras layer class and can be used as layer
    in a Keras model.

    Parameters
    ----------
    encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
        If ``encoder`` is provided as input, the following input parameters
        are not required and will be ignored: ``gen_poly``, ``rate``,
        ``constraint_length``, ``rsc``, ``terminate``. They will be inferred
        from the ``encoder`` object itself. If ``encoder`` is `None`, the
        above parameters must be provided explicitly.

    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If `None`,
        ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/2 and 1/3. Only required if ``gen_poly`` is
        `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the encoder is recursive-systematic
        for the given generator polynomials. `True` indicates the encoder
        is recursive-systematic. `False` indicates the encoder is
        feed-forward non-systematic.

    terminate: boolean
        Boolean flag indicating whether the codeword is terminated.
        `True` indicates the codeword is terminated to the all-zero state.
        `False` indicates the codeword is not terminated.

    hard_out: boolean
        Boolean flag indicating whether to output hard or soft decisions on
        the decoded information vector.
        `True` implies a hard-decided information vector of 0/1's as output.
        `False` implies the output is the decoded LLRs of the information
        bits.

    algorithm: str
        Defaults to `map`. Indicates the implemented BCJR algorithm,
        where `map` denotes the exact MAP algorithm, `log` indicates the
        exact MAP implementation, but in log-domain, and
        `maxlog` indicates the approximated MAP implementation in
        log-domain, where :math:`\log(e^{a}+e^{b}) \sim \max(a,b)`.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer.

    Input
    -----
    llr_ch or (llr_ch, llr_a) :
        Tensor or Tuple:

    llr_ch: [...,n], tf.float32
        2+D tensor containing the (noisy) channel
        LLRs, where `n` denotes the codeword length

    llr_a: [...,k], tf.float32
        2+D tensor containing the a priori information of each information
        bit. Implicitly assumed to be 0 if only ``llr_ch`` is provided.

    Output
    ------
    : tf.float32
        2+D tensor of shape `[...,rate*n]` containing the estimates of the
        information bit tensor
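
    Example
    -------
    A minimal usage sketch, assuming ``llr_ch`` holds channel LLRs of shape
    `[batch_size, n]` (the tensor name is chosen here for illustration):

    >>> decoder = BCJRDecoder(rate=1/2,
    ...                       constraint_length=3,
    ...                       algorithm='maxlog')
    >>> # u_hat has shape [batch_size, k]
    >>> u_hat = decoder(llr_ch)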
    """

    def __init__(self,
                 encoder=None,
                 gen_poly=None,
                 rate=1/2,
                 constraint_length=3,
                 rsc=False,
                 terminate=False,
                 hard_out=True,
                 algorithm='map',
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)
        if encoder is not None:
            self._gen_poly = encoder.gen_poly
            self._trellis = encoder.trellis
            self._terminate = encoder.terminate
        else:
            if gen_poly is not None:
                assert all(isinstance(poly, str) for poly in gen_poly), \
                    "Each polynomial must be a string."
                assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                    "Each polynomial must be of same length."
                assert all(all(
                    char in ['0','1'] for char in poly) for poly in gen_poly),\
                    "Each polynomial must be a string of 0's and 1's."
                self._gen_poly = gen_poly
            else:
                valid_rates = (1/2, 1/3)
                valid_constraint_length = (3, 4, 5, 6, 7, 8)

                assert constraint_length in valid_constraint_length, \
                    "Constraint length must be between 3 and 8."
                assert rate in valid_rates, \
                    "Rate must be 1/3 or 1/2."
                self._gen_poly = polynomial_selector(rate, constraint_length)

            # init Trellis parameters
            self._trellis = Trellis(self.gen_poly, rsc=rsc)
            self._terminate = terminate

        valid_algorithms = ['map', 'log', 'maxlog']
        assert algorithm in valid_algorithms, \
            "algorithm must be one of map, log or maxlog."

        self._coderate_desired = 1/len(self._gen_poly)
        self._mu = len(self._gen_poly[0])-1

        self._num_term_bits = None
        self._num_term_syms = None

        # conv_k denotes the number of input bit streams;
        # can only be 1 in the current implementation
        self._conv_k = self._trellis.conv_k
        assert self._conv_k == 1, \
            "Only codes with a single input bit stream are supported."
        # conv_n denotes the number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        # Length of the info-bit vector. Equal to _num_syms if
        # terminate=False, else < _num_syms
        self._k = None
        # Length of the codeword, including termination bits
        self._n = None
        # num_syms denotes the number of encoding periods (state transitions)
        self._num_syms = None

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._hard_out = hard_out
        self._algorithm = algorithm

        self._output_dtype = output_dtype
        self.ipst_op_idx, self.ipst_ip_idx = self._mask_by_tonode()

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._n is None:
            print("Note that, due to termination, the true coderate is "
                  "lower than the returned design rate. The exact true "
                  "rate depends on the value of n and hence cannot be "
                  "computed before the first call().")
            self._coderate = self._coderate_desired
        elif self.terminate and self._n is not None:
            k = self._coderate_desired*self._n - self._mu
            self._coderate = k/self._n
        else:
            self._coderate = self._coderate_desired
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the encoder is terminated during codeword generation"""
        return self._terminate

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first "
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first "
                  "call().")
        return self._n

    #########################
    # Utility functions
    #########################
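
    # Notation sketch for the helpers below, following [BCJR]_. With branch
    # metric gamma_t(s',s) computed from the channel and a priori LLRs, the
    # forward/backward recursions are
    #   alpha_t(s)     = sum_{s'} alpha_{t-1}(s') * gamma_t(s',s)
    #   beta_{t-1}(s') = sum_{s}  beta_t(s) * gamma_t(s',s)
    # and the output LLR of bit u_t combines both passes:
    #   L(u_t) = log( sum over (s',s) with u_t=0 of alpha_{t-1} gamma_t beta_t )
    #          - log( sum over (s',s) with u_t=1 of alpha_{t-1} gamma_t beta_t )
    # In the 'log' and 'maxlog' variants, products become sums and the sums
    # are evaluated with logsumexp and max, respectively.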
    def _mask_by_tonode(self):
        """
        Assume i->j is a valid state transition for info bit b that emits
        symbol k. Returns the following two _ns x _ni index matrices, each
        element of shape (2,):
        - st_op_idx: the jth row contains the (i,k) tuples
        - st_ip_idx: the jth row contains the (i,b) tuples

        When applied as a tf.gather_nd index on an _ns x _no matrix, the
        result is a matrix sorted by next state.

        E.g., tf.gather_nd applied to "input" (shape _ns x _no) with mask
        - st_op_idx: gathers input[i][k] into row j,
        - st_ip_idx: gathers input[i][b] into row j.
        """

        cnst = self._ns * self._ni
        from_nodes_vec = tf.reshape(self._trellis.from_nodes, (cnst,))
        op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,))
        st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx]))
        st_op_idx = tf.reshape(st_op_idx, (self._ns, self._ni, 2))

        ip_idx = tf.reshape(self._trellis.ip_by_tonode, (cnst,))
        st_ip_idx = tf.transpose(tf.stack([from_nodes_vec, ip_idx]))
        st_ip_idx = tf.reshape(st_ip_idx, (self._ns, self._ni, 2))

        return st_op_idx, st_ip_idx

    def _bmcalc(self, llr_in):
        """
        Calculate the branch (gamma) metrics for a given noisy codeword
        tensor. For each time period t, _bmcalc computes the "distance" of
        the symbol vector y[t] from each possible output symbol, i.e.,
        (2*Eb/N0) * sum_i x_i*y_i for i=1,2,...,conv_n.

        This metric is used in the calculation of gamma. No additional
        channel scaling is required since the input LLRs already equal
        2*Eb*y/N0.
        """
        op_bits = np.stack(
            [int2bin(op, self._conv_n) for op in range(self._no)])
        op_mat = tf.cast(tf.tile(op_bits, [1, self._num_syms]), tf.float32)
        op_mat = tf.expand_dims(op_mat, axis=0)
        llr_in = tf.expand_dims(llr_in, axis=1)
        op_mat_sign = 1. - 2. * op_mat

        llr_sign = tf.math.multiply(llr_in, op_mat_sign)
        half_llr_sign = tf.reshape(0.5 * llr_sign,
                        (-1, self._no, self._num_syms, self._conv_n))

        if self._algorithm in ['log', 'maxlog']:
            bm = tf.math.reduce_sum(half_llr_sign, axis=-1)
        else:
            bm = tf.math.exp(tf.math.reduce_sum(half_llr_sign, axis=-1))

        return bm

    def _initialize(self, llr_ch):
        """
        Initialize the alpha and beta state vectors. Encoding starts in the
        all-zero state; the final state is again the all-zero state if the
        code is terminated and is assumed equiprobable otherwise.
        """
        if self._algorithm in ['log', 'maxlog']:
            init_vals = -np.inf, 0.0
        else:
            init_vals = 0.0, 1.0
        alpha_init_np = np.full((self._ns,), init_vals[0])
        alpha_init_np[0] = init_vals[1]

        beta_init_np = alpha_init_np
        if not self._terminate:
            eq_prob = 1./self._ns
            if self._algorithm in ['log', 'maxlog']:
                eq_prob = np.log(eq_prob)
            beta_init_np = np.full((self._ns,), eq_prob)

        alpha_init = tf.convert_to_tensor(alpha_init_np, dtype=tf.float32)
        alpha_init = tf.tile(alpha_init[None,:], [tf.shape(llr_ch)[0], 1])
        beta_init = tf.convert_to_tensor(beta_init_np, dtype=tf.float32)
        beta_init = tf.tile(beta_init[None,:], [tf.shape(llr_ch)[0], 1])
        return alpha_init, beta_init
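
    # The 'maxlog' variant replaces every logsumexp merge by a max, e.g.,
    # log(e^2 + e^1) = 2.31..., whereas max(2, 1) = 2. Each two-way merge
    # errs by log(1 + e^-|a-b|) <= log(2), the usual accuracy/complexity
    # trade-off of max-log-MAP decoding.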
    def _update_fwd(self, alph_init, bm_mat, llr):
        """
        Run the forward update from time t=0 to t=k-1.
        At each time t, computes alpha_t from alpha_{t-1} and gamma_t.

        Returns a tensor array of alpha_t, t = 0,1,2,...,num_syms.
        """
        alph_ta = tf.TensorArray(tf.float32, size=self._num_syms+1,
                                 dynamic_size=False, clear_after_read=False)
        alph_prev = tf.cast(alph_init, tf.float32)

        # (bs, _ns, _ni, 2) matrix
        ipst_ip_mask = tf.tile(
            self.ipst_ip_idx[None,:], [tf.shape(alph_init)[0], 1, 1, 1])
        # (bs, _ns, _ni) matrix, by from-state
        op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
                          [tf.shape(alph_init)[0], 1, 1])
        ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
                            [tf.shape(alph_init)[0], self._ns, 1])
        ipbitsign_mat = 1. - 2. * tf.cast(ipbit_mat, tf.float32)
        alph_ta = alph_ta.write(0, alph_prev)
        for t in tf.range(self._num_syms):
            bm_t = bm_mat[..., t]
            llr_t = 0.5 * llr[...,t][:, None, None]

            bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)
            signed_half_llr = tf.math.multiply(
                tf.tile(llr_t, [1, self._ns, self._ni]), ipbitsign_mat)
            if self._algorithm in ['log', 'maxlog']:
                llr_byfromst = signed_half_llr
                gamma_byfromst = llr_byfromst + bm_byfromst
                alph_gam_prod = gamma_byfromst + alph_prev[:,:,None]
            else:
                llr_byfromst = tf.math.exp(signed_half_llr)
                gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)
                alph_gam_prod = tf.math.multiply(gamma_byfromst,
                                                 alph_prev[:,:,None])

            alphgam_bytost = tf.gather_nd(alph_gam_prod,
                                          ipst_ip_mask,
                                          batch_dims=1)
            if self._algorithm == 'map':
                alph_t = tf.math.reduce_sum(alphgam_bytost, axis=-1)
                alph_t_sum = tf.reduce_sum(alph_t, axis=-1)
                alph_t = tf.divide(alph_t,
                                   tf.tile(alph_t_sum[:,None], [1, self._ns]))
            elif self._algorithm == 'log':
                alph_t = tf.math.reduce_logsumexp(alphgam_bytost, axis=-1)
            else: # self._algorithm == 'maxlog'
                alph_t = tf.math.reduce_max(alphgam_bytost, axis=-1)

            alph_prev = alph_t
            alph_ta = alph_ta.write(t+1, alph_t)
        return alph_ta
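
    # In 'map' mode, the alphas above (and the betas below) are renormalized
    # to sum to one after every step. This leaves the final LLRs unchanged,
    # since the per-step scaling cancels in the LLR ratio, but prevents the
    # running products of probabilities from underflowing numerically.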
    def _update_bwd(self, beta_init, bm_mat, llr, alpha_ta):
        """
        Run the backward update from time t=k-1 to t=0.
        At each time t, computes beta_{t-1} from beta_t and gamma_t.

        Returns the LLRs of the information bits for t=0,1,...,k-1.
        """

        beta_next = beta_init
        llr_op_ta = tf.TensorArray(tf.float32,
                                   size=self._num_syms,
                                   dynamic_size=False,
                                   clear_after_read=False)
        beta_next = tf.cast(beta_next, tf.float32)

        # (bs, _ns, _ni) matrix, by from-state
        op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
                          [tf.shape(beta_init)[0], 1, 1])
        tonode_mask = tf.tile(self.trellis.to_nodes[None,:,:],
                              [tf.shape(beta_init)[0], 1, 1])

        ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
                            [tf.shape(beta_init)[0], self._ns, 1])
        ipbitsign_mat = 1.0 - 2.0 * tf.cast(ipbit_mat, tf.float32)

        for t in tf.range(self._num_syms-1, -1, -1):
            bm_t = bm_mat[..., t]
            llr_t = 0.5 * llr[...,t][:, None, None]
            signed_half_llr = tf.math.multiply(
                tf.tile(llr_t, [1, self._ns, self._ni]), ipbitsign_mat)
            bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)

            if self._algorithm in ['log', 'maxlog']:
                llr_byfromst = signed_half_llr
                gamma_byfromst = tf.math.add(llr_byfromst, bm_byfromst)
            else:
                llr_byfromst = tf.math.exp(signed_half_llr)
                gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)

            beta_bytonode = tf.gather(beta_next, tonode_mask, batch_dims=1)

            if self._algorithm == 'map':
                beta_gam_prod = tf.math.multiply(gamma_byfromst,
                                                 beta_bytonode)
                beta_t = tf.math.reduce_sum(beta_gam_prod, axis=-1)
                beta_t_sum = tf.reduce_sum(beta_t, axis=-1)
                beta_t = tf.divide(beta_t,
                                   tf.tile(beta_t_sum[:,None], [1, self._ns]))
            elif self._algorithm == 'log':
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = tf.math.reduce_logsumexp(beta_gam_prod, axis=-1,
                                                  keepdims=False)
            else: # self._algorithm == 'maxlog'
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = tf.math.reduce_max(beta_gam_prod, axis=-1)

            alph_t = alpha_ta.read(t)
            if self._algorithm == 'map':
                llr_op_t0 = tf.math.multiply(
                    tf.math.multiply(alph_t, gamma_byfromst[...,0]),
                    beta_bytonode[...,0])
                llr_op_t1 = tf.math.multiply(
                    tf.math.multiply(alph_t, gamma_byfromst[...,1]),
                    beta_bytonode[...,1])
                llr_op_t = tf.math.log(
                    tf.divide(tf.reduce_sum(llr_op_t0, axis=-1),
                              tf.reduce_sum(llr_op_t1, axis=-1)))
            else:
                llr_op_t0 = alph_t + gamma_byfromst[...,0] \
                            + beta_bytonode[...,0]
                llr_op_t1 = alph_t + gamma_byfromst[...,1] \
                            + beta_bytonode[...,1]
                if self._algorithm == 'log':
                    llr_op_t = tf.math.subtract(
                        tf.math.reduce_logsumexp(llr_op_t0, axis=-1),
                        tf.math.reduce_logsumexp(llr_op_t1, axis=-1))
                else: # self._algorithm == 'maxlog'
                    llr_op_t = tf.math.subtract(
                        tf.math.reduce_max(llr_op_t0, axis=-1),
                        tf.math.reduce_max(llr_op_t1, axis=-1))

            llr_op_ta = llr_op_ta.write(t, llr_op_t)
            beta_next = beta_t

        llr_op = tf.transpose(llr_op_ta.stack())
        return llr_op

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions."""
        # assert rank of at least two
        tf.debugging.assert_greater_equal(len(input_shape), 2)

        if isinstance(input_shape, tf.TensorShape):
            self._n = input_shape[-1]
        else:
            self._n = input_shape[0][-1]

        self._num_syms = int(self._n*self._coderate_desired)

        self._num_term_syms = self._mu if self._terminate else 0
        self._num_term_bits = int(self._num_term_syms/self._coderate_desired)

        self._k = self._num_syms - self._num_term_syms

    def call(self, inputs):
        """
        BCJR decoding function.

        inputs is the (noisy) codeword tensor where the last dimension
        should equal n. All leading dimensions are treated as batch
        dimensions.
        """
        if isinstance(inputs, (tuple, list)):
            assert len(inputs) == 2, \
                "inputs must be the tuple (llr_ch, llr_a)."
            llr_ch, llr_apr = inputs
        else:
            tf.debugging.assert_greater(tf.rank(inputs), 1)
            llr_ch = inputs
            llr_apr = None

        tf.debugging.assert_type(llr_ch,
                                 tf.float32,
                                 message="input must be tf.float32.")

        output_shape = llr_ch.get_shape().as_list()

        # allow different codeword lengths in eager mode
        if output_shape[-1] != self._n:
            if isinstance(inputs, (tuple, list)):
                self.build((inputs[0].get_shape(),
                            inputs[1].get_shape()))
            else:
                self.build(llr_ch.get_shape().as_list())

        output_shape[0] = -1
        output_shape[-1] = self._k # assign k to the last dimension
        llr_ch = tf.reshape(llr_ch, [-1, self._n])

        if llr_apr is None:
            llr_apr = tf.zeros((tf.shape(llr_ch)[0], self._num_syms),
                               dtype=tf.float32)
        llr_ch = -1. * llr_ch
        llr_apr = -1. * llr_apr

        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(llr_ch)
        alpha_init, beta_init = self._initialize(llr_ch)

        alph_ta = self._update_fwd(alpha_init, bm_mat, llr_apr)
        llr_op = self._update_bwd(beta_init, bm_mat, llr_apr, alph_ta)

        msghat = -1. * llr_op[...,:self._k]
        if self._hard_out: # hard-decide the decoder output if required
            msghat = tf.less(0.0, msghat)
        msghat = tf.cast(msghat, self._output_dtype)
        msghat_reshaped = tf.reshape(msghat, output_shape)

        return msghat_reshaped
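

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only; assumes the sionna
    # package with ConvEncoder is importable): encode random bits, form
    # noiseless BPSK LLRs, and decode with both decoders.
    from sionna.fec.conv.encoding import ConvEncoder

    batch_size, num_bits = 4, 100
    encoder = ConvEncoder(rate=1/2, constraint_length=3)
    u = tf.cast(tf.random.uniform((batch_size, num_bits), 0, 2, tf.int32),
                tf.float32)
    c = encoder(u)

    # Noiseless "LLRs": positive for bit 1, negative for bit 0
    # (sign convention assumed to match the decoders above)
    llr = 10. * (2.*c - 1.)

    u_hat_v = ViterbiDecoder(encoder=encoder, method='soft_llr')(llr)
    u_hat_b = BCJRDecoder(encoder=encoder, algorithm='maxlog')(llr)
    print("Viterbi bit errors:", int(tf.reduce_sum(tf.abs(u - u_hat_v))))
    print("BCJR bit errors:   ", int(tf.reduce_sum(tf.abs(u - u_hat_b))))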