anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

metrics.py (6646B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Useful metrics for the Sionna library."""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.metrics import Metric
      9 from tensorflow.keras.losses import BinaryCrossentropy
     10 
     11 
     12 class BitwiseMutualInformation(Metric):
     13     """BitwiseMutualInformation(name="bitwise_mutual_information", **kwargs)
     14 
     15     Computes the bitwise mutual information between bits and LLRs.
     16 
     17     This class implements a Keras metric for the bitwise mutual information
     18     between a tensor of bits and LLR (logits).
     19 
     20     Input
     21     -----
     22         bits : tf.float32
     23             A tensor of arbitrary shape filled with ones and zeros.
     24 
     25         llr : tf.float32
     26             A tensor of the same shape as ``bits`` containing logits.
     27 
     28     Output
     29     ------
     30         : tf.float32
     31             A scalar, the bit-wise mutual information.
     32 
     33     """
     34     def __init__(self, name="bitwise_mutual_information", **kwargs):
     35         super().__init__(name, **kwargs)
     36         self.bmi = self.add_weight(name="bmi", initializer="zeros",
     37                                    dtype=tf.float32)
     38         self.counter = self.add_weight(name="counter", initializer="zeros")
     39         self.bce = BinaryCrossentropy(from_logits=True)
     40 
     41     def update_state(self, bits, llr):
     42         self.counter.assign_add(1)
     43         self.bmi.assign_add(1-self.bce(bits, llr)/tf.math.log(2.))
     44 
     45     def result(self):
     46         return tf.cast(tf.math.divide_no_nan(self.bmi, self.counter),
     47                        dtype=tf.float32)
     48 
     49     def reset_state(self):
     50         self.bmi.assign(0.0)
     51         self.counter.assign(0.0)
     52 
     53 class BitErrorRate(Metric):
     54     """BitErrorRate(name="bit_error_rate", **kwargs)
     55 
     56     Computes the average bit error rate (BER) between two binary tensors.
     57 
     58     This class implements a Keras metric for the bit error rate
     59     between two tensors of bits.
     60 
     61     Input
     62     -----
     63         b : tf.float32
     64             A tensor of arbitrary shape filled with ones and
     65             zeros.
     66 
     67         b_hat : tf.float32
     68             A tensor of the same shape as ``b`` filled with
     69             ones and zeros.
     70 
     71     Output
     72     ------
     73         : tf.float32
     74             A scalar, the BER.
     75     """
     76     def __init__(self, name="bit_error_rate", **kwargs):
     77         super().__init__(name, **kwargs)
     78         self.ber = self.add_weight(name="ber",
     79                                    initializer="zeros",
     80                                    dtype=tf.float64)
     81         self.counter = self.add_weight(name="counter",
     82                                        initializer="zeros",
     83                                        dtype=tf.float64)
     84 
     85     def update_state(self, b, b_hat):
     86         self.counter.assign_add(1)
     87         self.ber.assign_add(compute_ber(b, b_hat))
     88 
     89     def result(self):
     90         #cast results of computer_ber for compatibility with tf.float32
     91         return tf.cast(tf.math.divide_no_nan(self.ber, self.counter),
     92                        dtype=tf.float32)
     93 
     94     def reset_state(self):
     95         self.ber.assign(0.0)
     96         self.counter.assign(0.0)
     97 
     98 def compute_ber(b, b_hat):
     99     """Computes the bit error rate (BER) between two binary tensors.
    100 
    101     Input
    102     -----
    103         b : tf.float32
    104             A tensor of arbitrary shape filled with ones and
    105             zeros.
    106 
    107         b_hat : tf.float32
    108             A tensor of the same shape as ``b`` filled with
    109             ones and zeros.
    110 
    111     Output
    112     ------
    113         : tf.float64
    114             A scalar, the BER.
    115     """
    116     ber = tf.not_equal(b, b_hat)
    117     ber = tf.cast(ber, tf.float64) # tf.float64 to suport large batch-sizes
    118     return tf.reduce_mean(ber)
    119 
    120 def compute_ser(s, s_hat):
    121     """Computes the symbol error rate (SER) between two integer tensors.
    122 
    123     Input
    124     -----
    125         s : tf.int
    126             A tensor of arbitrary shape filled with integers indicating
    127             the symbol indices.
    128 
    129         s_hat : tf.int
    130             A tensor of the same shape as ``s`` filled with integers indicating
    131             the estimated symbol indices.
    132 
    133     Output
    134     ------
    135         : tf.float64
    136             A scalar, the SER.
    137     """
    138     ser = tf.not_equal(s, s_hat)
    139     ser = tf.cast(ser, tf.float64) # tf.float64 to suport large batch-sizes
    140     return tf.reduce_mean(ser)
    141 
    142 def compute_bler(b, b_hat):
    143     """Computes the block error rate (BLER) between two binary tensors.
    144 
    145     A block error happens if at least one element of ``b`` and ``b_hat``
    146     differ in one block. The BLER is evaluated over the last dimension of
    147     the input, i. e., all elements of the last dimension are considered to
    148     define a block.
    149 
    150     This is also sometimes referred to as `word error rate` or `frame error
    151     rate`.
    152 
    153     Input
    154     -----
    155         b : tf.float32
    156             A tensor of arbitrary shape filled with ones and
    157             zeros.
    158 
    159         b_hat : tf.float32
    160             A tensor of the same shape as ``b`` filled with
    161             ones and zeros.
    162 
    163     Output
    164     ------
    165         : tf.float64
    166             A scalar, the BLER.
    167     """
    168     bler = tf.reduce_any(tf.not_equal(b, b_hat), axis=-1)
    169     bler = tf.cast(bler, tf.float64) # tf.float64 to suport large batch-sizes
    170     return tf.reduce_mean(bler)
    171 
    172 def count_errors(b, b_hat):
    173     """Counts the number of bit errors between two binary tensors.
    174 
    175     Input
    176     -----
    177         b : tf.float32
    178             A tensor of arbitrary shape filled with ones and
    179             zeros.
    180 
    181         b_hat : tf.float32
    182             A tensor of the same shape as ``b`` filled with
    183             ones and zeros.
    184 
    185     Output
    186     ------
    187         : tf.int64
    188             A scalar, the number of bit errors.
    189     """
    190     errors = tf.not_equal(b,b_hat)
    191     errors = tf.cast(errors, tf.int64)
    192     return tf.reduce_sum(errors)
    193 
    194 def count_block_errors(b, b_hat):
    195     """Counts the number of block errors between two binary tensors.
    196 
    197     A block error happens if at least one element of ``b`` and ``b_hat``
    198     differ in one block. The BLER is evaluated over the last dimension of
    199     the input, i. e., all elements of the last dimension are considered to
    200     define a block.
    201 
    202     This is also sometimes referred to as `word error rate` or `frame error
    203     rate`.
    204 
    205     Input
    206     -----
    207         b : tf.float32
    208             A tensor of arbitrary shape filled with ones and
    209             zeros.
    210 
    211         b_hat : tf.float32
    212             A tensor of the same shape as ``b`` filled with
    213             ones and zeros.
    214 
    215     Output
    216     ------
    217         : tf.int64
    218             A scalar, the number of block errors.
    219     """
    220     errors = tf.reduce_any(tf.not_equal(b,b_hat), axis=-1)
    221     errors = tf.cast(errors, tf.int64)
    222     return tf.reduce_sum(errors)
    223