anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

tensors.py (11798B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Functions extending TensorFlow tensor operations"""
      6 
      7 import tensorflow as tf
      8 import sionna as sn
      9 
     10 def expand_to_rank(tensor, target_rank, axis=-1):
     11     """Inserts as many axes to a tensor as needed to achieve a desired rank.
     12 
     13     This operation inserts additional dimensions to a ``tensor`` starting at
     14     ``axis``, so that so that the rank of the resulting tensor has rank
     15     ``target_rank``. The dimension index follows Python indexing rules, i.e.,
     16     zero-based, where a negative index is counted backward from the end.
     17 
     18     Args:
     19         tensor : A tensor.
     20         target_rank (int) : The rank of the output tensor.
     21             If ``target_rank`` is smaller than the rank of ``tensor``,
     22             the function does nothing.
     23         axis (int) : The dimension index at which to expand the
     24                shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
     25                ``axis`` must be within the range `[-(D+1), D]` (inclusive).
     26 
     27     Returns:
     28         A tensor with the same data as ``tensor``, with
     29         ``target_rank``- rank(``tensor``) additional dimensions inserted at the
     30         index specified by ``axis``.
     31         If ``target_rank`` <= rank(``tensor``), ``tensor`` is returned.
     32     """
     33     num_dims = tf.maximum(target_rank - tf.rank(tensor), 0)
     34     output = insert_dims(tensor, num_dims, axis)
     35 
     36     return output
     37 
     38 def flatten_dims(tensor, num_dims, axis):
     39     """
     40     Flattens a specified set of dimensions of a tensor.
     41 
     42     This operation flattens ``num_dims`` dimensions of a ``tensor``
     43     starting at a given ``axis``.
     44 
     45     Args:
     46         tensor : A tensor.
     47         num_dims (int): The number of dimensions
     48             to combine. Must be larger than two and less or equal than the
     49             rank of ``tensor``.
     50         axis (int): The index of the dimension from which to start.
     51 
     52     Returns:
     53         A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
     54         dimensions, but the same number of elements.
     55     """
     56     msg = "`num_dims` must be >= 2"
     57     tf.debugging.assert_greater_equal(num_dims, 2, msg)
     58 
     59     msg = "`num_dims` must <= rank(`tensor`)"
     60     tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)
     61 
     62     msg = "0<= `axis` <= rank(tensor)-1"
     63     tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
     64     tf.debugging.assert_greater_equal(axis, 0, msg)
     65 
     66     msg ="`num_dims`+`axis` <= rank(`tensor`)"
     67     tf.debugging.assert_less_equal(num_dims + axis, tf.rank(tensor), msg)
     68 
     69     if num_dims==len(tensor.shape):
     70         new_shape = [-1]
     71     elif axis==0:
     72         shape = tf.shape(tensor)
     73         new_shape = tf.concat([[-1], shape[axis+num_dims:]], 0)
     74     else:
     75         shape = tf.shape(tensor)
     76         flat_dim = tf.reduce_prod(tensor.shape[axis:axis+num_dims])
     77         new_shape = tf.concat([shape[:axis],
     78                                [flat_dim],
     79                                shape[axis+num_dims:]], 0)
     80 
     81     return tf.reshape(tensor, new_shape)
     82 
     83 def flatten_last_dims(tensor, num_dims=2):
     84     """
     85     Flattens the last `n` dimensions of a tensor.
     86 
     87     This operation flattens the last ``num_dims`` dimensions of a ``tensor``.
     88     It is a simplified version of the function ``flatten_dims``.
     89 
     90     Args:
     91         tensor : A tensor.
     92         num_dims (int): The number of dimensions
     93             to combine. Must be greater than or equal to two and less or equal
     94             than the rank of ``tensor``.
     95 
     96     Returns:
     97         A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
     98         dimensions, but the same number of elements.
     99     """
    100     msg = "`num_dims` must be >= 2"
    101     tf.debugging.assert_greater_equal(num_dims, 2, msg)
    102 
    103     msg = "`num_dims` must <= rank(`tensor`)"
    104     tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)
    105 
    106     if num_dims==len(tensor.shape):
    107         new_shape = [-1]
    108     else:
    109         shape = tf.shape(tensor)
    110         last_dim = tf.reduce_prod(tensor.shape[-num_dims:])
    111         new_shape = tf.concat([shape[:-num_dims], [last_dim]], 0)
    112 
    113     return tf.reshape(tensor, new_shape)
    114 
    115 def insert_dims(tensor, num_dims, axis=-1):
    116     """Adds multiple length-one dimensions to a tensor.
    117 
    118     This operation is an extension to TensorFlow`s ``expand_dims`` function.
    119     It inserts ``num_dims`` dimensions of length one starting from the
    120     dimension ``axis`` of a ``tensor``. The dimension
    121     index follows Python indexing rules, i.e., zero-based, where a negative
    122     index is counted backward from the end.
    123 
    124     Args:
    125         tensor : A tensor.
    126         num_dims (int) : The number of dimensions to add.
    127         axis : The dimension index at which to expand the
    128                shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
    129                ``axis`` must be within the range `[-(D+1), D]` (inclusive).
    130 
    131     Returns:
    132         A tensor with the same data as ``tensor``, with ``num_dims`` additional
    133         dimensions inserted at the index specified by ``axis``.
    134     """
    135     msg = "`num_dims` must be nonnegative."
    136     tf.debugging.assert_greater_equal(num_dims, 0, msg)
    137 
    138     rank = tf.rank(tensor)
    139     msg = "`axis` is out of range `[-(D+1), D]`)"
    140     tf.debugging.assert_less_equal(axis, rank, msg)
    141     tf.debugging.assert_greater_equal(axis, -(rank+1), msg)
    142 
    143     axis = axis if axis>=0 else rank+axis+1
    144     shape = tf.shape(tensor)
    145     new_shape = tf.concat([shape[:axis],
    146                            tf.ones([num_dims], tf.int32),
    147                            shape[axis:]], 0)
    148     output = tf.reshape(tensor, new_shape)
    149 
    150     return output
    151 
    152 def split_dim(tensor, shape, axis):
    153     """Reshapes a dimension of a tensor into multiple dimensions.
    154 
    155     This operation splits the dimension ``axis`` of a ``tensor`` into
    156     multiple dimensions according to ``shape``.
    157 
    158     Args:
    159         tensor : A tensor.
    160         shape (list or TensorShape): The shape to which the dimension should
    161             be reshaped.
    162         axis (int): The index of the axis to be reshaped.
    163 
    164     Returns:
    165         A tensor of the same type as ``tensor`` with len(``shape``)-1
    166         additional dimensions, but the same number of elements.
    167     """
    168     msg = "0<= `axis` <= rank(tensor)-1"
    169     tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    170     tf.debugging.assert_greater_equal(axis, 0, msg)
    171 
    172     s = tf.shape(tensor)
    173     new_shape = tf.concat([s[:axis], shape, s[axis+1:]], 0)
    174     output = tf.reshape(tensor, new_shape)
    175 
    176     return output
    177 
    178 def matrix_sqrt(tensor):
    179     r""" Computes the square root of a matrix.
    180 
    181     Given a batch of Hermitian positive semi-definite matrices
    182     :math:`\mathbf{A}`, returns matrices :math:`\mathbf{B}`,
    183     such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`.
    184 
    185     The two inner dimensions are assumed to correspond to the matrix rows
    186     and columns, respectively.
    187 
    188     Args:
    189         tensor ([..., M, M]) : A tensor of rank greater than or equal
    190             to two.
    191 
    192     Returns:
    193         A tensor of the same shape and type as ``tensor`` containing
    194         the matrix square root of its last two dimensions.
    195 
    196     Note:
    197         If you want to use this function in Graph mode with XLA, i.e., within
    198         a function that is decorated with ``@tf.function(jit_compile=True)``,
    199         you must set ``sionna.config.xla_compat=true``.
    200         See :py:attr:`~sionna.config.xla_compat`.
    201     """
    202     if sn.config.xla_compat and not tf.executing_eagerly():
    203         s, u = tf.linalg.eigh(tensor)
    204 
    205         # Compute sqrt of eigenvalues
    206         s = tf.abs(s)
    207         s = tf.sqrt(s)
    208         s = tf.cast(s, u.dtype)
    209 
    210         # Matrix multiplication
    211         s = tf.expand_dims(s, -2)
    212         return tf.matmul(u*s, u, adjoint_b=True)
    213     else:
    214         return tf.linalg.sqrtm(tensor)
    215 
    216 def matrix_sqrt_inv(tensor):
    217     r""" Computes the inverse square root of a Hermitian matrix.
    218 
    219     Given a batch of Hermitian positive definite matrices
    220     :math:`\mathbf{A}`, with square root matrices :math:`\mathbf{B}`,
    221     such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`, the function
    222     returns :math:`\mathbf{B}^{-1}`, such that
    223     :math:`\mathbf{B}^{-1}\mathbf{B}=\mathbf{I}`.
    224 
    225     The two inner dimensions are assumed to correspond to the matrix rows
    226     and columns, respectively.
    227 
    228     Args:
    229         tensor ([..., M, M]) : A tensor of rank greater than or equal
    230             to two.
    231 
    232     Returns:
    233         A tensor of the same shape and type as ``tensor`` containing
    234         the inverse matrix square root of its last two dimensions.
    235 
    236     Note:
    237         If you want to use this function in Graph mode with XLA, i.e., within
    238         a function that is decorated with ``@tf.function(jit_compile=True)``,
    239         you must set ``sionna.Config.xla_compat=true``.
    240         See :py:attr:`~sionna.Config.xla_compat`.
    241     """
    242     if sn.config.xla_compat and not tf.executing_eagerly():
    243         s, u = tf.linalg.eigh(tensor)
    244 
    245         # Compute 1/sqrt of eigenvalues
    246         s = tf.abs(s)
    247         tf.debugging.assert_positive(s, "Input must be positive definite.")
    248         s = 1/tf.sqrt(s)
    249         s = tf.cast(s, u.dtype)
    250 
    251         # Matrix multiplication
    252         s = tf.expand_dims(s, -2)
    253         return tf.matmul(u*s, u, adjoint_b=True)
    254     else:
    255         return tf.linalg.inv(tf.linalg.sqrtm(tensor))
    256 
    257 def matrix_inv(tensor):
    258     r""" Computes the inverse of a Hermitian matrix.
    259 
    260     Given a batch of Hermitian positive definite matrices
    261     :math:`\mathbf{A}`, the function
    262     returns :math:`\mathbf{A}^{-1}`, such that
    263     :math:`\mathbf{A}^{-1}\mathbf{A}=\mathbf{I}`.
    264 
    265     The two inner dimensions are assumed to correspond to the matrix rows
    266     and columns, respectively.
    267 
    268     Args:
    269         tensor ([..., M, M]) : A tensor of rank greater than or equal
    270             to two.
    271 
    272     Returns:
    273         A tensor of the same shape and type as ``tensor``, containing
    274         the inverse of its last two dimensions.
    275 
    276     Note:
    277         If you want to use this function in Graph mode with XLA, i.e., within
    278         a function that is decorated with ``@tf.function(jit_compile=True)``,
    279         you must set ``sionna.Config.xla_compat=true``.
    280         See :py:attr:`~sionna.Config.xla_compat`.
    281     """
    282     if tensor.dtype in [tf.complex64, tf.complex128] \
    283                     and sn.config.xla_compat \
    284                     and not tf.executing_eagerly():
    285         s, u = tf.linalg.eigh(tensor)
    286 
    287         # Compute inverse of eigenvalues
    288         s = tf.abs(s)
    289         tf.debugging.assert_positive(s, "Input must be positive definite.")
    290         s = 1/s
    291         s = tf.cast(s, u.dtype)
    292 
    293         # Matrix multiplication
    294         s = tf.expand_dims(s, -2)
    295         return tf.matmul(u*s, u, adjoint_b=True)
    296     else:
    297         return tf.linalg.inv(tensor)
    298 
    299 def matrix_pinv(tensor):
    300     r""" Computes the Moore–Penrose (or pseudo) inverse of a matrix.
    301 
    302     Given a batch of :math:`M \times K` matrices :math:`\mathbf{A}` with rank
    303     :math:`K` (i.e., linearly independent columns), the function returns
    304     :math:`\mathbf{A}^+`, such that
    305     :math:`\mathbf{A}^{+}\mathbf{A}=\mathbf{I}_K`.
    306 
    307     The two inner dimensions are assumed to correspond to the matrix rows
    308     and columns, respectively.
    309 
    310     Args:
    311         tensor ([..., M, K]) : A tensor of rank greater than or equal
    312             to two.
    313 
    314     Returns:
    315         A tensor of shape ([..., K,K]) of the same type as ``tensor``,
    316         containing the pseudo inverse of its last two dimensions.
    317 
    318     Note:
    319         If you want to use this function in Graph mode with XLA, i.e., within
    320         a function that is decorated with ``@tf.function(jit_compile=True)``,
    321         you must set ``sionna.config.xla_compat=true``.
    322         See :py:attr:`~sionna.config.xla_compat`.
    323     """
    324     inv = matrix_inv(tf.matmul(tensor, tensor, adjoint_a=True))
    325     return tf.matmul(inv, tensor, adjoint_b=True)