metrics.py (6646B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Useful metrics for the Sionna library."""

import tensorflow as tf
from tensorflow.keras.metrics import Metric
from tensorflow.keras.losses import BinaryCrossentropy


class BitwiseMutualInformation(Metric):
    """BitwiseMutualInformation(name="bitwise_mutual_information", **kwargs)

    Computes the bitwise mutual information between bits and LLRs.

    This class implements a Keras metric for the bitwise mutual information
    between a tensor of bits and LLR (logits). Each call to
    :meth:`update_state` contributes one estimate, computed as one minus the
    binary cross-entropy (converted from nats to bits); :meth:`result`
    returns the average of these estimates over all calls.

    Input
    -----
    bits : tf.float32
        A tensor of arbitrary shape filled with ones and zeros.

    llr : tf.float32
        A tensor of the same shape as ``bits`` containing logits.

    Output
    ------
    : tf.float32
        A scalar, the bit-wise mutual information.

    """
    def __init__(self, name="bitwise_mutual_information", **kwargs):
        # ``name`` is passed by keyword on purpose: the positional order of
        # ``Metric.__init__`` is not the same across Keras versions (Keras 3
        # takes ``dtype`` first), so a positional ``name`` could end up being
        # interpreted as the dtype.
        super().__init__(name=name, **kwargs)
        # Running sum of the per-update mutual information estimates.
        self.bmi = self.add_weight(name="bmi", initializer="zeros",
                                   dtype=tf.float32)
        # Number of calls to ``update_state`` since the last reset.
        self.counter = self.add_weight(name="counter", initializer="zeros")
        self.bce = BinaryCrossentropy(from_logits=True)

    def update_state(self, bits, llr):
        """Accumulates the estimate for one batch of ``bits`` and ``llr``."""
        self.counter.assign_add(1)
        # Dividing by log(2) converts the cross-entropy from nats to bits.
        self.bmi.assign_add(1 - self.bce(bits, llr) / tf.math.log(2.))

    def result(self):
        """Returns the average over all updates (zero if there were none)."""
        return tf.cast(tf.math.divide_no_nan(self.bmi, self.counter),
                       dtype=tf.float32)

    def reset_state(self):
        """Resets the accumulated sum and the update counter to zero."""
        self.bmi.assign(0.0)
        self.counter.assign(0.0)


class BitErrorRate(Metric):
    """BitErrorRate(name="bit_error_rate", **kwargs)

    Computes the average bit error rate (BER) between two binary tensors.

    This class implements a Keras metric for the bit error rate
    between two tensors of bits. Each call to :meth:`update_state`
    contributes the BER of the given pair of tensors; :meth:`result`
    returns the average of these values over all calls.

    Input
    -----
    b : tf.float32
        A tensor of arbitrary shape filled with ones and
        zeros.

    b_hat : tf.float32
        A tensor of the same shape as ``b`` filled with
        ones and zeros.

    Output
    ------
    : tf.float32
        A scalar, the BER.
    """
    def __init__(self, name="bit_error_rate", **kwargs):
        # ``name`` is passed by keyword on purpose: the positional order of
        # ``Metric.__init__`` is not the same across Keras versions (Keras 3
        # takes ``dtype`` first), so a positional ``name`` could end up being
        # interpreted as the dtype.
        super().__init__(name=name, **kwargs)
        # Running sum of the per-update BER values (tf.float64, matching the
        # output dtype of ``compute_ber``).
        self.ber = self.add_weight(name="ber",
                                   initializer="zeros",
                                   dtype=tf.float64)
        # Number of calls to ``update_state`` since the last reset.
        self.counter = self.add_weight(name="counter",
                                       initializer="zeros",
                                       dtype=tf.float64)

    def update_state(self, b, b_hat):
        """Accumulates the BER between ``b`` and ``b_hat`` for one batch."""
        self.counter.assign_add(1)
        self.ber.assign_add(compute_ber(b, b_hat))

    def result(self):
        """Returns the average over all updates (zero if there were none)."""
        # Cast the tf.float64 result of compute_ber for compatibility with
        # tf.float32
        return tf.cast(tf.math.divide_no_nan(self.ber, self.counter),
                       dtype=tf.float32)

    def reset_state(self):
        """Resets the accumulated sum and the update counter to zero."""
        self.ber.assign(0.0)
        self.counter.assign(0.0)


def compute_ber(b, b_hat):
    """Computes the bit error rate (BER) between two binary tensors.

    Input
    -----
    b : tf.float32
        A tensor of arbitrary shape filled with ones and
        zeros.

    b_hat : tf.float32
        A tensor of the same shape as ``b`` filled with
        ones and zeros.

    Output
    ------
    : tf.float64
        A scalar, the BER.
    """
    ber = tf.not_equal(b, b_hat)
    ber = tf.cast(ber, tf.float64)  # tf.float64 to support large batch-sizes
    return tf.reduce_mean(ber)


def compute_ser(s, s_hat):
    """Computes the symbol error rate (SER) between two integer tensors.

    Input
    -----
    s : tf.int
        A tensor of arbitrary shape filled with integers indicating
        the symbol indices.

    s_hat : tf.int
        A tensor of the same shape as ``s`` filled with integers indicating
        the estimated symbol indices.

    Output
    ------
    : tf.float64
        A scalar, the SER.
    """
    ser = tf.not_equal(s, s_hat)
    ser = tf.cast(ser, tf.float64)  # tf.float64 to support large batch-sizes
    return tf.reduce_mean(ser)


def compute_bler(b, b_hat):
    """Computes the block error rate (BLER) between two binary tensors.

    A block error happens if at least one element of ``b`` and ``b_hat``
    differ in one block. The BLER is evaluated over the last dimension of
    the input, i.e., all elements of the last dimension are considered to
    define a block.

    This is also sometimes referred to as `word error rate` or `frame error
    rate`.

    Input
    -----
    b : tf.float32
        A tensor of arbitrary shape filled with ones and
        zeros.

    b_hat : tf.float32
        A tensor of the same shape as ``b`` filled with
        ones and zeros.

    Output
    ------
    : tf.float64
        A scalar, the BLER.
    """
    # A block is in error as soon as any element of its last dimension differs
    bler = tf.reduce_any(tf.not_equal(b, b_hat), axis=-1)
    bler = tf.cast(bler, tf.float64)  # tf.float64 to support large batch-sizes
    return tf.reduce_mean(bler)


def count_errors(b, b_hat):
    """Counts the number of bit errors between two binary tensors.

    Input
    -----
    b : tf.float32
        A tensor of arbitrary shape filled with ones and
        zeros.

    b_hat : tf.float32
        A tensor of the same shape as ``b`` filled with
        ones and zeros.

    Output
    ------
    : tf.int64
        A scalar, the number of bit errors.
    """
    errors = tf.not_equal(b, b_hat)
    errors = tf.cast(errors, tf.int64)
    return tf.reduce_sum(errors)


def count_block_errors(b, b_hat):
    """Counts the number of block errors between two binary tensors.

    A block error happens if at least one element of ``b`` and ``b_hat``
    differ in one block. Block errors are evaluated over the last dimension
    of the input, i.e., all elements of the last dimension are considered to
    define a block.

    A block error is also sometimes referred to as a `word error` or
    `frame error`.

    Input
    -----
    b : tf.float32
        A tensor of arbitrary shape filled with ones and
        zeros.

    b_hat : tf.float32
        A tensor of the same shape as ``b`` filled with
        ones and zeros.

    Output
    ------
    : tf.int64
        A scalar, the number of block errors.
    """
    # A block is in error as soon as any element of its last dimension differs
    errors = tf.reduce_any(tf.not_equal(b, b_hat), axis=-1)
    errors = tf.cast(errors, tf.int64)
    return tf.reduce_sum(errors)