utils.py (8678B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions for the Polar code package."""

import numbers

import numpy as np
from scipy.special import comb

# NOTE: ``matplotlib`` (only used for the optional sparsity plot) and the
# resource helpers / bundled ``codes`` package (only used to read the 5G
# table) are imported lazily inside the functions that need them.


def _load_5g_channel_order():
    """Read the bundled ``codes/polar_5G.csv`` table and return it as an
    int ndarray (``;``-separated; columns 0 and 1 are used by
    :func:`generate_5g_ranking`)."""
    # pylint: disable=import-outside-toplevel
    from importlib_resources import files, as_file
    from . import codes  # pylint: disable=relative-beyond-top-level

    source = files(codes).joinpath("polar_5G.csv")
    # as_file() provides a real filesystem path for the resource; bind it to
    # a local name (never to an attribute of the imported package).
    with as_file(source) as csv_path:
        ch_order = np.genfromtxt(csv_path, delimiter=";")
    return ch_order.astype(int)


def generate_5g_ranking(k, n, sort=True):
    """Returns information and frozen bit positions of the 5G Polar code
    as defined in Tab. 5.3.1.2-1 in [3GPPTS38212]_ for given values of ``k``
    and ``n``.

    Input
    -----
    k: int
        The number of information bits per codeword.

    n: int
        The desired codeword length. Must be a power of two.

    sort: bool
        Defaults to True. Indicates if the returned indices are
        sorted.

    Output
    ------
    [frozen_pos, info_pos]:
        List:

        frozen_pos: ndarray
            An array of ints of shape `[n-k]` containing the frozen
            position indices.

        info_pos: ndarray
            An array of ints of shape `[k]` containing the information
            position indices.

    Raises
    ------
    AssertionError
        If ``k`` or ``n`` are not integers or ``k`` is negative.

    AssertionError
        If ``sort`` is not bool.

    AssertionError
        If ``k`` or ``n`` are larger than 1024.

    AssertionError
        If ``n`` is less than 32.

    AssertionError
        If the resulting coderate is invalid (`>1.0`).

    AssertionError
        If ``n`` is not a power of 2.
    """
    # Any integral type (incl. NumPy integers) is accepted; continue with
    # plain Python ints.
    assert isinstance(k, numbers.Integral), "k must be integer."
    assert isinstance(n, numbers.Integral), "n must be integer."
    assert isinstance(sort, bool), "sort must be bool."
    k, n = int(k), int(n)
    assert k > -1, "k cannot be negative."
    assert k < 1025, "k cannot be larger than 1024."
    assert n < 1025, "n cannot be larger than 1024."
    assert n > 31, "n must be >=32."
    assert n >= k, "Invalid coderate (>1)."
    # Exact integer power-of-two test (n > 0 is guaranteed above).
    assert (n & (n - 1)) == 0, "n must be a power of 2."

    ch_order = _load_5g_channel_order()

    # Sort by the channel index (column 1) and keep the first n rows, i.e.,
    # the n smallest channel indices (presumably exactly 0..n-1, assuming the
    # table lists every index once).
    by_index = ch_order[np.argsort(ch_order[:, 1])][:n]
    # Re-order those n channels by reliability (column 0, ascending).
    by_reliability = by_index[np.argsort(by_index[:, 0])]

    # The first n-k channels in this order are frozen, the remaining k carry
    # information. Column 1 holds the channel index.
    channels = by_reliability[:, 1]
    frozen_pos = channels[:n - k].astype(int)
    info_pos = channels[n - k:].astype(int)

    if sort:
        frozen_pos = np.sort(frozen_pos)
        info_pos = np.sort(info_pos)

    return [frozen_pos, info_pos]


def generate_polar_transform_mat(n_lift):
    """Generate the polar transformation matrix (Kronecker product).

    Input
    -----
    n_lift: int
        Defining the Kronecker power, i.e., how often is the kernel lifted.
        Integral floats (e.g. ``3.0``) are accepted.

    Output
    ------
    : ndarray
        Integer array of `0s` and `1s` of shape `[2^n_lift , 2^n_lift]`
        containing the Polar transformation matrix, i.e., the
        ``n_lift``-fold Kronecker power of ``[[1, 0], [1, 1]]``
        (``[[1]]`` for ``n_lift=0``).

    Raises
    ------
    AssertionError
        If ``n_lift`` is not integral, negative, or ``>= 12``.
    """
    assert int(n_lift)==n_lift, "n_lift must be integer"
    assert n_lift>=0, "n_lift must be positive"
    assert n_lift<12, "Warning: the resulting code length is large (=2^n_lift)."
    n_lift = int(n_lift)  # range() below requires a true int

    kernel = np.array([[1, 0], [1, 1]])

    # Start from the 1x1 identity (0-th Kronecker power) so that n_lift=0
    # yields a matrix of shape [1, 1] as documented.
    gm = np.ones((1, 1), dtype=kernel.dtype)
    for _ in range(n_lift):
        gm = np.kron(gm, kernel)
    return gm


def generate_rm_code(r, m):
    """Generate frozen positions of the (r, m) Reed Muller (RM) code.

    Input
    -----
    r: int
        The order of the RM code.

    m: int
        `log2` of the desired codeword length.

    Output
    ------
    [frozen_pos, info_pos, n, k, d_min]:
        Tuple:

        frozen_pos: ndarray
            An array of ints of shape `[n-k]` containing the frozen
            position indices.

        info_pos: ndarray
            An array of ints of shape `[k]` containing the information
            position indices.

        n: int
            Resulting codeword length

        k: int
            Number of information bits

        d_min: int
            Minimum distance of the code.

    Raises
    ------
    AssertionError
        If ``r`` is larger than ``m``.

    AssertionError
        If ``r`` or ``m`` are not non-negative ints.
    """
    assert isinstance(r, numbers.Integral), "r must be int."
    assert isinstance(m, numbers.Integral), "m must be int."
    r, m = int(r), int(m)
    assert r<=m, "order r cannot be larger than m."
    assert r>=0, "r must be positive."
    assert m>=0, "m must be positive."

    n = 2**m
    d_min = 2**(m-r)

    # Expected dimension k = sum_{i=0}^{r} binom(m, i), used as a cross-check.
    k = sum(int(comb(m, i, exact=True)) for i in range(r + 1))

    # Freeze every index whose binary Hamming weight is < m-r; the remaining
    # indices form the information set.
    weights = np.array([bin(i).count("1") for i in range(n)])
    frozen_vec = weights < m - r
    frozen_pos = np.flatnonzero(frozen_vec)
    info_pos = np.flatnonzero(~frozen_vec)

    assert len(info_pos)==k, "Error: resulting k is inconsistent."

    return frozen_pos, info_pos, n, k, d_min


def generate_dense_polar(frozen_pos, n, verbose=True):
    """Generate *naive* (dense) Polar parity-check and generator matrix.

    This function follows Lemma 1 in [Goala_LP]_ and returns a parity-check
    matrix for Polar codes.

    Note
    ----
    The resulting matrix can be used for decoding with the
    :class:`~sionna.fec.ldpc.LDPCBPDecoder` class. However, the resulting
    parity-check matrix is (usually) not sparse and, thus, not suitable for
    belief propagation decoding as the graph has many short cycles.
    Please consider :class:`~sionna.fec.polar.PolarBPDecoder` for iterative
    decoding over the encoding graph.

    Input
    -----
    frozen_pos: ndarray
        Array (or sequence) of `int` defining the ``n-k`` distinct indices of
        the frozen positions, each in ``[0, n)``.

    n: int
        The codeword length.

    verbose: bool
        Defaults to True. If True, the matrix shapes are printed and the
        sparsity pattern of the parity-check matrix is plotted.

    Output
    ------
    pcm: ndarray of `zeros` and `ones` of shape [n-k, n]
        The parity-check matrix.

    gm: ndarray of `zeros` and `ones` of shape [k, n]
        The generator matrix.

    Raises
    ------
    AssertionError
        If ``n`` is not a positive power of 2, or ``frozen_pos`` is not
        integer-valued, too long, out of range, or contains duplicates.
    """
    assert isinstance(n, numbers.Number), "n must be a number."
    n = int(n) # n can be float (e.g. as result of n=k*r)

    frozen_pos = np.asarray(frozen_pos)
    if frozen_pos.size == 0:
        # np.asarray([]) is float64; an empty set of positions is valid.
        frozen_pos = frozen_pos.astype(int)
    assert np.issubdtype(frozen_pos.dtype, np.integer), \
        "frozen_pos must consist of ints."
    assert len(frozen_pos)<=n, \
        "Number of elements in frozen_pos cannot be greater than n."
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of 2."
    assert np.all((frozen_pos >= 0) & (frozen_pos < n)), \
        "frozen_pos contains indices outside of [0, n)."
    assert len(np.unique(frozen_pos)) == len(frozen_pos), \
        "frozen_pos contains duplicate indices."

    k = n - len(frozen_pos)

    # generate info positions
    info_pos = np.setdiff1d(np.arange(n), frozen_pos)
    assert k==len(info_pos), "Internal error: invalid " \
                             "info_pos generated."

    # n is a power of two, so bit_length()-1 is the exact log2(n).
    gm_mat = generate_polar_transform_mat(n.bit_length() - 1)

    gm_true = gm_mat[info_pos,:]
    pcm = np.transpose(gm_mat[:,frozen_pos])

    if verbose:
        # Lazy import: matplotlib is only needed for this optional plot.
        import matplotlib.pyplot as plt  # pylint: disable=import-outside-toplevel
        print("Shape of the generator matrix: ", gm_true.shape)
        print("Shape of the parity-check matrix: ", pcm.shape)
        plt.spy(pcm)

    # Verify result, i.e., check that H*G has an all-zero syndrome.
    # Note: we have no proof that Lemma 1 holds for all possible
    # frozen_positions. Thus, it seems to be better to verify the generated
    # results individually.
    s = np.mod(np.matmul(pcm, np.transpose(gm_true)),2)
    assert np.sum(s)==0, "Non-zero syndrom for H*G'."

    return pcm, gm_true