decoding.py (17046B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for decoding of linear codes."""

import tensorflow as tf
import numpy as np
import scipy as sp # for sparse H matrix computations
from tensorflow.keras.layers import Layer
from sionna.fec.utils import pcm2gm, int_mod_2, make_systematic
from sionna.utils import hard_decisions
import itertools

class OSDecoder(Layer):
    # pylint: disable=line-too-long
    r"""OSDecoder(enc_mat=None, t=0, is_pcm=False, encoder=None, dtype=tf.float32, **kwargs)

    Ordered statistics decoding (OSD) for binary, linear block codes.

    This layer implements the OSD algorithm as proposed in [Fossorier]_ and,
    thereby, approximates maximum likelihood decoding for a sufficiently large
    order :math:`t`. The algorithm works for arbitrary linear block codes, but
    has a high computational complexity for long codes.

    The algorithm consists of the following steps:

        1. Sort LLRs according to their reliability and apply the same column
        permutation to the generator matrix.

        2. Bring the permuted generator matrix into its systematic form
        (so-called *most-reliable basis*).

        3. Hard-decide and re-encode the :math:`k` most reliable bits and
        discard the remaining :math:`n-k` received positions.

        4. Generate all possible error patterns up to :math:`t` errors in the
        :math:`k` most reliable positions and find the most likely codeword
        within these candidates.

    This implementation of the OSD algorithm uses the LLR-based distance metric
    from [Stimming_LLR_OSD]_ which simplifies the handling of higher-order
    modulation schemes.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    enc_mat : [k, n] or [n-k, n], ndarray
        Binary generator matrix of shape `[k, n]`. If ``is_pcm`` is
        True, ``enc_mat`` is interpreted as parity-check matrix of shape
        `[n-k, n]`.

    t : int
        Order of the OSD algorithm

    is_pcm: bool
        Defaults to False. If True, ``enc_mat`` is interpreted as parity-check
        matrix.

    encoder: Layer
        Keras layer that implements a FEC encoder.
        If not None, ``enc_mat`` will be ignored and the code as specified by
        the encoder is used to initialize OSD.

    dtype: tf.DType
        Defaults to `tf.float32`. Defines the datatype for the output dtype.

    Input
    -----
    llrs_ch: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    Output
    ------
    : [...,n], tf.float32
        2+D Tensor of same shape as ``llrs_ch`` containing
        binary hard-decisions of all codeword bits.

    Note
    ----
    OS decoding is of high complexity and is only feasible for small values of
    :math:`t` as :math:`{n \choose t}` patterns must be evaluated. The
    advantage of OSD is that it works for arbitrary linear block codes and
    provides an estimate of the expected ML performance for sufficiently large
    :math:`t`. However, for some code families, more efficient decoding
    algorithms with close to ML performance exist which can exploit certain
    code specific properties. Examples of such decoders are the
    :class:`~sionna.fec.conv.ViterbiDecoder` algorithm for convolutional codes
    or the :class:`~sionna.fec.polar.decoding.PolarSCLDecoder` for Polar codes
    (for a sufficiently large list size).

    It is recommended to run the decoder in XLA mode as it
    significantly reduces the memory complexity.
    """

    def __init__(self,
                 enc_mat=None,
                 t=0,
                 is_pcm=False,
                 encoder=None,
                 dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(is_pcm, bool), 'is_pcm must be bool.'

        self._llr_max = 100. # internal clipping value for llrs

        if enc_mat is not None:
            # check that gm is binary
            if isinstance(enc_mat, np.ndarray):
                assert np.array_equal(enc_mat, enc_mat.astype(bool)), \
                                        'PC matrix must be binary.'
            elif isinstance(enc_mat, sp.sparse.csr_matrix):
                assert np.array_equal(enc_mat.data, enc_mat.data.astype(bool)),\
                                        'PC matrix must be binary.'
            elif isinstance(enc_mat, sp.sparse.csc_matrix):
                assert np.array_equal(enc_mat.data, enc_mat.data.astype(bool)),\
                                        'PC matrix must be binary.'
            else:
                raise TypeError("Unsupported dtype of pcm.")

        if dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'dtype must be {tf.float16, tf.float32, tf.float64}.')

        assert (int(t)==t), "t must be int."
        self._t = int(t)

        if encoder is not None:
            # test that encoder is already initialized (relevant for conv codes)
            if encoder.k is None:
                raise AttributeError("It seems as if the encoder is not "\
                                     "initialized or has no attribute k.")
            # encode identity matrix to get k basis vectors of the code
            u = tf.expand_dims(tf.eye(encoder.k), axis=0)
            # encode and remove batch_dim
            self._gm = tf.cast(tf.squeeze(encoder(u), axis=0), self.dtype)
        else:
            assert (enc_mat is not None),\
                "enc_mat cannot be None if no encoder is provided."
            if is_pcm:
                gm = pcm2gm(enc_mat)
            else:
                # check if gm is of full rank (raise error otherwise)
                make_systematic(enc_mat)
                gm = enc_mat
            self._gm = tf.constant(gm, dtype=self.dtype)

        self._k = self._gm.shape[0]
        self._n = self._gm.shape[1]

        # init error patterns
        num_patterns = self._num_error_patterns(self._n, self._t)

        # storage/computational complexity scales with n
        num_symbols = num_patterns * self._n
        if num_symbols>1e9: # number still to be optimized
            print(f"Note: Required memory complexity is large for the "\
                  f"given code parameters and t={t}. Please consider small " \
                  f"batch-sizes to keep the inference complexity small and " \
                  f"activate XLA mode if possible." )
        if num_symbols>1e11: # number still to be optimized
            raise ResourceWarning("Due to its high complexity, OSD is not " \
                                  "feasible for the selected parameters. " \
                                  "Please consider using a smaller value for t.")

        # pre-compute all error patterns
        self._err_patterns = []
        for t_i in range(1, t+1):
            self._err_patterns.append(self._gen_error_patterns(self._k, t_i))

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gm(self):
        """Generator matrix of the code"""
        return self._gm

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits per codeword"""
        return self._k

    @property
    def t(self):
        """Order of the OSD algorithm"""
        return self._t

    #########################
    # Utility methods
    #########################

    def _num_error_patterns(self, n, t):
        r"""Returns number of possible error patterns for t errors in n
        positions, i.e., calculates :math:`{n \choose t}`.

        Input
        -----
        n: int
            length of vector.

        t: int
            number of errors.
        """
        return sp.special.comb(n, t, exact=True, repetition=False)

    def _gen_error_patterns(self, n, t):
        r"""Returns list of all possible error patterns for t errors in n
        positions.

        Input
        -----
        n: int
            Length of vector.

        t: int
            Number of errors.

        Output
        ------
        : [num_patterns, t], tf.int32
            Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
            t error indices.
        """

        err_patterns = []
        for p in itertools.combinations(range(n), t):
            err_patterns.append(p)

        return tf.constant(err_patterns)

    def _get_dist(self, llr, c_hat):
        """Distance function used for ML candidate selection.

        Currently, the distance metric from Polar decoding [Stimming_LLR_OSD]_
        literature is implemented.

        Input
        -----
        llr: [bs, n], tf.float32
            Received llrs of the channel observations.

        c_hat: [bs, num_cand, n], tf.float32
            Candidate codewords for which the distance to ``llr`` shall be
            evaluated.

        Output
        ------
        : [bs, num_cand], tf.float32
            Distance between ``llr`` and ``c_hat`` for each of the `num_cand`
            codeword candidates.

        Reference
        ---------
        [Stimming_LLR_OSD] Alexios Balatsoukas-Stimming, Mani Bastani Parizi,
        Andreas Burg, "LLR-Based Successive Cancellation List Decoding
        of Polar Codes." IEEE Trans Signal Processing, 2015.
        """

        # broadcast llr to all codeword candidates
        llr = tf.expand_dims(llr, axis=1)
        llr_sign = llr * (-2.*c_hat + 1.) # apply BPSK mapping

        # softplus(x) = log(1 + exp(x)) evaluated in a numerically stable way.
        # The naive expression overflows to inf for llr_sign >~ 88 in float32
        # (and much earlier in float16); llrs are clipped to +-self._llr_max=100,
        # so the overflow region is reachable and would break the argmin-based
        # candidate selection.
        d = tf.math.softplus(llr_sign)
        return tf.reduce_mean(d, axis=2)

    def _find_min_dist(self, llr_ch, ep, gm_mrb, c):
        r"""Find error pattern which leads to minimum distance.

        Input
        -----
        llr_ch: [bs, n], tf.float32
            Channel observations as llrs after mrb sorting.

        ep: [num_patterns, t], tf.int32
            Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
            t error indices.

        gm_mrb: [bs, k, n] tf.float32
            Most reliable basis for each batch example.

        c: [bs, n], tf.float32
            Most reliable base codeword.

        Output
        ------
        : [bs], tf.float32
            Distance of the most likely codeword to ``llr_ch`` after testing all
            ``ep`` error patterns.

        : [bs, n], tf.float32
            The most likely codeword after testing against all ``ep`` error
            patterns.
        """

        # generate all test candidates for each possible error pattern
        e = tf.gather(gm_mrb, ep, axis=1)
        e = tf.reduce_sum(e, axis=2)
        e += tf.expand_dims(c, axis=1) # add to mrb codeword
        c_cand = int_mod_2(e) # apply modulo-2 operation

        # calculate distance for each candidate
        # where c_cand has shape [bs, num_patterns, n]
        d = self._get_dist(llr_ch, c_cand)

        # find candidate index with smallest metric
        idx = tf.argmin(d, axis=1)
        c_hat = tf.gather(c_cand, idx, batch_dims=1)
        d = tf.gather(d, idx, batch_dims=1)
        return d, c_hat

    def _find_mrb(self, gm):
        """Find most reliable basis for all generator matrices in batch.

        Input
        -----
        gm: [bs, k, n] tf.float32
            Generator matrix for each batch example.

        Output
        ------
        gm_mrb: [bs, k, n] tf.float32
            Most reliable basis in systematic form for each batch example.

        idx_sort: [bs, n] tf.int64
            Indices of column permutations applied during mrb calculation.
        """

        bs = tf.shape(gm)[0]
        s = gm.shape
        idx_pivot = tf.TensorArray(tf.int64, self._k, dynamic_size=False)

        # bring gm in systematic form (by so-called pivot method)
        for idx_c in tf.range(self._k):

            # ensure shape to avoid XLA incompatibility with TF2.11 in tf.range
            gm = tf.ensure_shape(gm, s)

            # find pivot (i.e., first pos with index 1)
            idx_p = tf.argmax(gm[:, idx_c, :], axis=-1)

            # store pivot position
            idx_pivot = idx_pivot.write(idx_c, idx_p)

            # and eliminate the column in all other rows
            r = tf.gather(gm, idx_p, batch_dims=1, axis=-1)

            # ignore idx_c row itself by adding all-zero row
            rz = tf.zeros((bs, 1), dtype=self.dtype)
            r = tf.concat([r[:,:idx_c], rz , r[:,idx_c+1:]], axis=1)

            # mask is zero at all rows where pivot position of this row is zero
            mask = tf.tile(tf.expand_dims(r, axis=-1), (1, 1, self._n))
            gm_off = tf.expand_dims(gm[:,idx_c,:], axis=1)

            # update all row in parallel
            gm = int_mod_2(gm + mask * gm_off) # account for binary operations

        # pivot positions
        idx_pivot = tf.transpose(idx_pivot.stack())

        # find non-pivot positions (i.e., all indices that are not part of
        # idx_pivot)

        # solution 1: sets.difference() does not support XLA (unknown shapes)
        #idx_parity = tf.sets.difference(idx_range, idx_pivot)
        #idx_parity = tf.sparse.to_dense(idx_parity)
        #idx_pivot = tf.reshape(idx_pivot, (-1, self._n)) # ensure shape

        # solution 2: add large offset to pivot indices and sorting gives the
        # indices of interest
        idx_range = tf.tile(tf.expand_dims(
                                tf.range(self._n, dtype=tf.int64), axis=0),
                            (bs, 1))
        # large value to be added to irrelevant indices
        updates = self._n * tf.ones((bs, self._k), tf.int64)

        # generate indices for tf.scatter_nd_add
        s = tf.shape(idx_pivot, tf.int64)
        ii, _ = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
        idx_updates = tf.stack([ii, idx_pivot], axis=-1)

        # add large value to pivot positions
        idx = tf.tensor_scatter_nd_add(idx_range, idx_updates, updates)

        # sort and slice first n-k indices (equals parity positions)
        idx_parity = tf.cast(tf.argsort(idx)[:,:self._n-self._k], tf.int64)

        idx_sort = tf.concat([idx_pivot, idx_parity], axis=1)

        # permute gm according to indices idx_sort
        gm = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)

        return gm, idx_sort

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Nothing to build, but check for valid shapes."""

        assert input_shape[-1]==self._n, "Invalid input shape."

    def call(self, inputs):
        r"""Applies ordered statistic decoding to inputs.

        Remark: the decoder is implemented with llr definition
        llr = p(x=1)/p(x=0).
        """

        # flatten batch-dim
        input_shape = tf.shape(inputs)
        llr_ch = tf.reshape(inputs, (-1, self._n))
        llr_ch = tf.cast(llr_ch, self.dtype)
        bs = tf.shape(llr_ch)[0]

        # clip inputs
        llr_ch = tf.clip_by_value(llr_ch, -self._llr_max, self._llr_max)

        # step 1: sort LLRs
        idx_sort = tf.argsort(tf.abs(llr_ch), direction="DESCENDING")

        # permute gm per batch sample individually
        gm = tf.broadcast_to(tf.expand_dims(self._gm, axis=0),
                             (bs, self._k, self._n))
        gm_sort = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)

        # step 2: Find most reliable basis (MRB)
        gm_mrb, idx_mrb = self._find_mrb(gm_sort)

        # apply corresponding mrb permutations
        idx_sort = tf.gather(idx_sort, idx_mrb, batch_dims=1)
        llr_sort = tf.gather(llr_ch, idx_sort, batch_dims=1)

        # find inverse permutation for final output
        idx_sort_inv = tf.argsort(idx_sort)

        # step 3: hard-decide k most reliable positions and re-encode
        u_hd = hard_decisions(llr_sort[:,0:self._k])
        u_hd = tf.expand_dims(u_hd, axis=1)
        c = tf.squeeze(tf.matmul(u_hd, gm_mrb), axis=1)
        c = int_mod_2(c)

        # step 4: search for most likely error pattern
        # _get_dist expects a list of candidates, thus expand_dims to [bs, 1, n]
        d_best = self._get_dist(llr_sort, tf.expand_dims(c, axis=1))
        d_best = tf.squeeze(d_best, axis=1)
        c_hat_best = c

        # known in advance - can be unrolled
        for ep in self._err_patterns:
            # compute distance for all candidate codewords
            d, c_hat = self._find_min_dist(llr_sort, ep, gm_mrb, c)

            # select most likely candidate
            ind = tf.expand_dims(d<d_best, axis=1)
            c_hat_best = tf.where(ind, c_hat, c_hat_best)
            d_best = tf.where(d<d_best, d, d_best)

        # undo permutations for final codeword
        c_hat_best = tf.gather(c_hat_best, idx_sort_inv, axis=1, batch_dims=1)
        # restore original input shape
        c_hat = tf.reshape(c_hat_best, input_shape)

        return c_hat