discrete_channel.py (21873B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for discrete channel models"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna import config
from sionna.utils import expand_to_rank

class BinaryMemorylessChannel(Layer):
    # pylint: disable=line-too-long
    r"""BinaryMemorylessChannel(return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs)

    Discrete binary memoryless channel with (possibly) asymmetric bit flipping
    probabilities.

    Input bits are flipped with probability :math:`p_\text{b,0}` and
    :math:`p_\text{b,1}`, respectively.

    .. figure:: ../figures/BMC_channel.png
        :align: center

    This layer supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \frac{p_{b,1}}{1-p_{b,0}}, \qquad \text{if} \, y=0 \\
            \operatorname{log} \frac{1-p_{b,1}}{p_{b,0}}, \qquad \text{if} \, y=1 \\
        \end{cases}

    The error probability :math:`p_\text{b}` can be either a scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different bit flipping probabilities per bit position. In any case, its
    last dimension must be of length 2 and is interpreted as
    :math:`p_\text{b,0}` and :math:`p_\text{b,1}`.

    This class inherits from the Keras `Layer` class and can be used as layer
    in a Keras model.

    Parameters
    ----------

    return_llrs: bool
        Defaults to `False`. If `True`, the layer returns log-likelihood
        ratios instead of binary values based on ``pb``.

    bipolar_input : bool, False
        Defaults to `False`. If `True`, the expected input is given as
        :math:`\{-1,1\}` instead of :math:`\{0,1\}`.

    llr_max: tf.float
        Defaults to 100. Defines the clipping value of the LLRs.

    dtype : tf.DType
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.float32`.

    Input
    -----
    (x, pb) :
        Tuple:

    x : [...,n], tf.float32
        Input sequence to the channel consisting of binary values
        :math:`\{0,1\}` or :math:`\{-1,1\}`, respectively.

    pb : [...,2], tf.float32
        Error probability. Can be a tuple of two scalars or of any
        shape that can be broadcasted to the shape of ``x``. It has an
        additional last dimension which is interpreted as
        :math:`p_\text{b,0}` and :math:`p_\text{b,1}`.

    Output
    -------
        : [...,n], tf.float32
            Output sequence of same length as the input ``x``. If
            ``return_llrs`` is `False`, the output is binary (or bipolar)
            and otherwise soft-values (LLRs) are returned.
    """

    def __init__(self, return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs):

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(return_llrs, bool), "return_llrs must be bool."
        self._return_llrs = return_llrs

        assert isinstance(bipolar_input, bool), "bipolar_input must be bool."
        self._bipolar_input = bipolar_input

        assert llr_max>=0., "llr_max must be a positive scalar value."
        self._llr_max = tf.cast(llr_max, dtype=self.dtype)

        # dtype restrictions depend on the operation mode:
        # LLR outputs need floats; bipolar inputs need signed dtypes (-1)
        if self._return_llrs:
            assert dtype in (tf.float16, tf.float32, tf.float64),\
                "LLR outputs require non-integer dtypes."
        else:
            if self._bipolar_input:
                assert dtype in (tf.float16, tf.float32, tf.float64,
                                 tf.int8, tf.int16, tf.int32, tf.int64),\
                    "Only, signed dtypes are supported for bipolar inputs."
            else:
                assert dtype in (tf.float16, tf.float32, tf.float64,
                                 tf.uint8, tf.uint16, tf.uint32, tf.uint64,
                                 tf.int8, tf.int16, tf.int32, tf.int64),\
                    "Only, real-valued dtypes are supported."

        self._check_input = True # check input for consistency (i.e., binary)

        self._eps = 1e-9 # small additional term for numerical stability
        self._temperature = tf.constant(0.1, tf.float32) # for Gumbel-softmax

    #########################################
    # Public methods and properties
    #########################################

    @property
    def llr_max(self):
        """Maximum value used for LLR calculations."""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value):
        """Maximum value used for LLR calculations."""
        assert value>=0, 'llr_max cannot be negative.'
        # cast to self.dtype (not tf.float32) for consistency with __init__;
        # otherwise clipping the LLRs in call() would mix dtypes for
        # dtype=tf.float64/tf.float16
        self._llr_max = tf.cast(value, dtype=self.dtype)

    @property
    def temperature(self):
        """Temperature for Gumbel-softmax trick."""
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        """Temperature for Gumbel-softmax trick."""
        assert value>=0, 'temperature cannot be negative.'
        self._temperature = tf.cast(value, dtype=tf.float32)

    #########################
    # Utility methods
    #########################

    def _check_inputs(self, x):
        """Check input x for consistency, i.e., verify
        that all values are binary or bipolar values."""
        x = tf.cast(x, tf.float32)
        if self._check_input:
            if self._bipolar_input: # allow -1 and 1 for bipolar inputs
                values = (tf.constant(-1, x.dtype), tf.constant(1, x.dtype))
            else: # allow 0,1 for binary input
                values = (tf.constant(0, x.dtype), tf.constant(1, x.dtype))
            tf.debugging.assert_equal(
                tf.reduce_min(tf.cast(tf.logical_or(tf.equal(x, values[0]),
                                                    tf.equal(x, values[1])),
                                      x.dtype)),
                tf.constant(1, x.dtype),
                "Input must be binary.")
            # input datatype consistency should be only evaluated once
            self._check_input = False

    @tf.custom_gradient
    def _custom_xor(self, a, b):
        """Straight through estimator for XOR."""
        def grad(upstream):
            """identity in backward direction"""
            return upstream, upstream
        # xor in forward path
        # use modulo for "exotic" dtypes
        if self.dtype in (tf.uint8, tf.uint16, tf.uint32, tf.uint64,
                          tf.int8, tf.int16, tf.int32, tf.int64):
            z = tf.math.mod(a+b, tf.constant(2, self.dtype))
        else: # use abs for float dtypes
            z = tf.abs(a - b)

        return z, grad

    @tf.custom_gradient
    def _ste_binarizer(self, x):
        """Straight through binarizer to quantize bits to int values."""
        def grad(upstream):
            """identity in backward direction"""
            return upstream
        # hard-decide in forward path
        z = tf.where(x<.5, 0., 1.)
        return z, grad

    def _sample_errors(self, pb, shape):
        """Samples binary error vector with given error probability e.
        This function is based on the Gumbel-softmax "trick" to keep the
        sampling differentiable."""

        # this implementation follows https://arxiv.org/pdf/1611.01144v5.pdf
        # and https://arxiv.org/pdf/1906.07748.pdf

        u1 = config.tf_rng.uniform(shape=shape,
                                   minval=0.,
                                   maxval=1.,
                                   dtype=tf.float32)
        u2 = config.tf_rng.uniform(shape=shape,
                                   minval=0.,
                                   maxval=1.,
                                   dtype=tf.float32)
        u = tf.stack((u1, u2), axis=-1)

        # sample Gumbel distribution
        q = - tf.math.log(- tf.math.log(u + self._eps) + self._eps)
        p = tf.stack((pb, 1-pb), axis=-1)
        p = expand_to_rank(p, tf.rank(q), axis=0)
        p = tf.broadcast_to(p, tf.shape(q))
        a = (tf.math.log(p + self._eps) + q) / self._temperature

        # apply softmax
        e_cat = tf.nn.softmax(a)

        # binarize final values via straight-through estimator
        return self._ste_binarizer(e_cat[...,0]) # only take first class

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Verify correct input shapes"""

        pb_shapes = input_shapes[1]
        # allow tuple of scalars as alternative input
        if isinstance(pb_shapes, (tuple, list)):
            if not len(pb_shapes)==2:
                raise ValueError("Last dim of pb must be of length 2.")
        else:
            if len(pb_shapes)>0:
                if not pb_shapes[-1]==2:
                    raise ValueError("Last dim of pb must be of length 2.")
            else:
                raise ValueError("Last dim of pb must be of length 2.")

    def call(self, inputs):
        """Apply discrete binary memoryless channel to inputs."""

        x, pb = inputs

        # allow pb to be a tuple of two scalars
        if isinstance(pb, (tuple, list)):
            pb0 = pb[0]
            pb1 = pb[1]
        else:
            pb0 = pb[...,0]
            pb1 = pb[...,1]

        # clip for numerical stability
        pb0 = tf.cast(pb0, tf.float32) # Gumbel requires float dtypes
        pb1 = tf.cast(pb1, tf.float32) # Gumbel requires float dtypes
        pb0 = tf.clip_by_value(pb0, 0., 1.)
        pb1 = tf.clip_by_value(pb1, 0., 1.)

        # check x for consistency (binary, bipolar)
        self._check_inputs(x)

        # sample two independent error patterns, one per input symbol
        e0 = self._sample_errors(pb0, tf.shape(x))
        e1 = self._sample_errors(pb1, tf.shape(x))

        # the "zero" symbol is -1 in bipolar mode
        if self._bipolar_input:
            neutral_element = tf.constant(-1, dtype=x.dtype)
        else:
            neutral_element = tf.constant(0, dtype=x.dtype)

        # mask e0 and e1 with input such that e0 only applies where x==0
        e = tf.where(x==neutral_element, e0, e1)
        e = tf.cast(e, x.dtype)

        if self._bipolar_input:
            # flip signs for bipolar case
            y = x * (-2*e + 1)
        else:
            # XOR for binary case
            y = self._custom_xor(x, e)

        # if LLRs should be returned
        if self._return_llrs:
            if not self._bipolar_input:
                y = 2 * y - 1 # transform to bipolar

            # Remark: Sionna uses the logit definition log[p(x=1)/p(x=0)]
            y0 = - (tf.math.log(pb1 + self._eps)
                    - tf.math.log(1 - pb0 - self._eps))
            y1 = (tf.math.log(1 - pb1 - self._eps)
                  - tf.math.log(pb0 + self._eps))
            # multiply by y to keep gradient
            y = tf.cast(tf.where(y==1, y1, y0), dtype=y.dtype) * y
            # and clip output llrs
            y = tf.clip_by_value(y, -self._llr_max, self._llr_max)

        return y
class BinarySymmetricChannel(BinaryMemorylessChannel):
    # pylint: disable=line-too-long
    r"""BinarySymmetricChannel(return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs)

    Discrete binary symmetric channel which randomly flips bits with
    probability :math:`p_\text{b}`.

    .. figure:: ../figures/BSC_channel.png
        :align: center

    This layer supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \frac{p_{b}}{1-p_{b}}, \qquad \text{if}\, y=0 \\
            \operatorname{log} \frac{1-p_{b}}{p_{b}}, \qquad \text{if}\, y=1 \\
        \end{cases}

    where :math:`y` denotes the binary output of the channel.

    The bit flipping probability :math:`p_\text{b}` can be either a scalar or
    a tensor (broadcastable to the shape of the input). This allows
    different bit flipping probabilities per bit position.

    This class inherits from the Keras `Layer` class and can be used as layer
    in a Keras model.

    Parameters
    ----------

    return_llrs: bool
        Defaults to `False`. If `True`, the layer returns log-likelihood
        ratios instead of binary values based on ``pb``.

    bipolar_input : bool, False
        Defaults to `False`. If `True`, the expected input is given as {-1,1}
        instead of {0,1}.

    llr_max: tf.float
        Defaults to 100. Defines the clipping value of the LLRs.

    dtype : tf.DType
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.float32`.

    Input
    -----
    (x, pb) :
        Tuple:

    x : [...,n], tf.float32
        Input sequence to the channel.

    pb : tf.float32
        Bit flipping probability. Can be a scalar or of any shape that
        can be broadcasted to the shape of ``x``.

    Output
    -------
        : [...,n], tf.float32
            Output sequence of same length as the input ``x``. If
            ``return_llrs`` is `False`, the output is binary and otherwise
            soft-values are returned.
    """

    def __init__(self, return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs):

        super().__init__(return_llrs=return_llrs,
                         bipolar_input=bipolar_input,
                         llr_max=llr_max,
                         dtype=dtype,
                         **kwargs)

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Verify correct input shapes"""
        pass # nothing to verify here

    def call(self, inputs):
        """Apply discrete binary symmetric channel, i.e., randomly flip
        bits with probability pb."""

        x, pb = inputs

        # the BSC is implemented by calling the DMC with symmetric pb;
        # cast pb to float32 instead of x.dtype: for integer input dtypes
        # (which the parent class explicitly allows) the cast would truncate
        # the probability to 0/1, and the Gumbel-softmax sampling in the
        # parent class requires float dtypes anyway
        pb = tf.cast(pb, tf.float32)
        pb = tf.stack((pb, pb), axis=-1)
        y = super().call((x, pb))

        return y

class BinaryZChannel(BinaryMemorylessChannel):
    # pylint: disable=line-too-long
    r"""BinaryZChannel(return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs)

    Layer that implements the binary Z-channel.

    In the Z-channel, transmission errors only occur for the transmission of
    the second input element (i.e., if a `1` is transmitted) with error
    probability :math:`p_\text{b}`, but the first element is always correctly
    received.

    .. figure:: ../figures/Z_channel.png
        :align: center


    This layer supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \left( p_b \right), \qquad \text{if} \, y=0 \\
            \infty, \qquad \qquad \text{if} \, y=1 \\
        \end{cases}

    assuming equal probable inputs :math:`P(X=0) = P(X=1) = 0.5`.

    The error probability :math:`p_\text{b}` can be either a scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different error probabilities per bit position.

    This class inherits from the Keras `Layer` class and can be used as layer
    in a Keras model.

    Parameters
    ----------

    return_llrs: bool
        Defaults to `False`. If `True`, the layer returns log-likelihood
        ratios instead of binary values based on ``pb``.

    bipolar_input : bool, False
        Defaults to `False`. If True, the expected input is given as {-1,1}
        instead of {0,1}.

    llr_max: tf.float
        Defaults to 100. Defines the clipping value of the LLRs.

    dtype : tf.DType
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.float32`.

    Input
    -----
    (x, pb) :
        Tuple:

    x : [...,n], tf.float32
        Input sequence to the channel.

    pb : tf.float32
        Error probability. Can be a scalar or of any shape that can be
        broadcasted to the shape of ``x``.

    Output
    -------
        : [...,n], tf.float32
            Output sequence of same length as the input ``x``. If
            ``return_llrs`` is `False`, the output is binary and otherwise
            soft-values are returned.
    """

    def __init__(self, return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs):

        super().__init__(return_llrs=return_llrs,
                         bipolar_input=bipolar_input,
                         llr_max=llr_max,
                         dtype=dtype,
                         **kwargs)

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Verify correct input shapes"""
        pass # nothing to verify here

    def call(self, inputs):
        """Apply the binary Z-channel, i.e., randomly flip transmitted `1`s
        with probability pb while `0`s are always received correctly."""

        x, pb = inputs

        # the Z-channel is implemented by calling the DMC with p(1|0)=0;
        # cast pb to float32 instead of x.dtype: for integer input dtypes
        # (which the parent class explicitly allows) the cast would truncate
        # the probability to 0/1, and the Gumbel-softmax sampling in the
        # parent class requires float dtypes anyway
        pb = tf.cast(pb, tf.float32)
        pb = tf.stack((tf.zeros_like(pb), pb), axis=-1)
        y = super().call((x, pb))

        return y


class BinaryErasureChannel(BinaryMemorylessChannel):
    # pylint: disable=line-too-long
    r"""BinaryErasureChannel(return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs)

    Binary erasure channel (BEC) where a bit is either correctly received
    or erased.

    In the binary erasure channel, bits are always correctly received or
    erased with erasure probability :math:`p_\text{b}`.

    .. figure:: ../figures/BEC_channel.png
        :align: center

    This layer supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            -\infty, \qquad \text{if} \, y=0 \\
            0, \qquad \quad \,\, \text{if} \, y=? \\
            \infty, \qquad \quad \text{if} \, y=1 \\
        \end{cases}

    The erasure probability :math:`p_\text{b}` can be either a scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different erasure probabilities per bit position.

    Please note that the output of the BEC is ternary. Hereby, `-1` indicates
    an erasure for the binary configuration and `0` for the bipolar mode,
    respectively.

    This class inherits from the Keras `Layer` class and can be used as layer
    in a Keras model.

    Parameters
    ----------

    return_llrs: bool
        Defaults to `False`. If `True`, the layer returns log-likelihood
        ratios instead of binary values based on ``pb``.

    bipolar_input : bool, False
        Defaults to `False`. If `True`, the expected input is given as {-1,1}
        instead of {0,1}.

    llr_max: tf.float
        Defaults to 100. Defines the clipping value of the LLRs.

    dtype : tf.DType
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.float32`.

    Input
    -----
    (x, pb) :
        Tuple:

    x : [...,n], tf.float32
        Input sequence to the channel.

    pb : tf.float32
        Erasure probability. Can be a scalar or of any shape that can be
        broadcasted to the shape of ``x``.

    Output
    -------
        : [...,n], tf.float32
            Output sequence of same length as the input ``x``. If
            ``return_llrs`` is `False`, the output is ternary where each `-1`
            and each `0` indicate an erasure for the binary and bipolar
            input, respectively.
    """

    def __init__(self, return_llrs=False, bipolar_input=False, llr_max=100., dtype=tf.float32, **kwargs):

        super().__init__(return_llrs=return_llrs,
                         bipolar_input=bipolar_input,
                         llr_max=llr_max,
                         dtype=dtype,
                         **kwargs)

        # also exclude uints, as -1 indicator for erasures does not exist
        assert dtype in (tf.float16, tf.float32, tf.float64,
                         tf.int8, tf.int16, tf.int32, tf.int64),\
            "Unsigned integers are currently not supported."

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Verify correct input shapes"""
        pass # nothing to verify here

    def call(self, inputs):
        """Apply erasure channel to inputs."""

        x, pb = inputs

        # clip for numerical stability
        pb = tf.cast(pb, tf.float32) # Gumbel requires float dtypes
        pb = tf.clip_by_value(pb, 0., 1.)

        # check x for consistency (binary, bipolar)
        self._check_inputs(x)

        # sample erasure pattern
        e = self._sample_errors(pb, tf.shape(x))

        # if LLRs should be returned
        # remark: the Sionna logit definition is llr = log[p(x=1)/p(x=0)]
        if self._return_llrs:
            if not self._bipolar_input:
                x = 2 * x - 1
            x *= tf.cast(self._llr_max, x.dtype) # calculate llrs

            # erase positions by setting llrs to 0
            y = tf.where(e==1, tf.constant(0, x.dtype), x)
        else: # ternary outputs
            # the erasure indicator depends on the operation mode
            if self._bipolar_input:
                erased_element = tf.constant(0, dtype=x.dtype)
            else:
                erased_element = tf.constant(-1, dtype=x.dtype)

            y = tf.where(e==0, x, erased_element)
        return y