tensors.py (11798B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions extending TensorFlow tensor operations"""

import tensorflow as tf
import sionna as sn

def expand_to_rank(tensor, target_rank, axis=-1):
    """Inserts as many axes to a tensor as needed to achieve a desired rank.

    This operation inserts additional dimensions to a ``tensor`` starting at
    ``axis``, so that the resulting tensor has rank ``target_rank``.
    The dimension index follows Python indexing rules, i.e.,
    zero-based, where a negative index is counted backward from the end.

    Args:
        tensor : A tensor.
        target_rank (int) : The rank of the output tensor.
            If ``target_rank`` is smaller than the rank of ``tensor``,
            the function does nothing.
        axis (int) : The dimension index at which to expand the
            shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
            ``axis`` must be within the range `[-(D+1), D]` (inclusive).

    Returns:
        A tensor with the same data as ``tensor``, with
        ``target_rank``- rank(``tensor``) additional dimensions inserted at the
        index specified by ``axis``.
        If ``target_rank`` <= rank(``tensor``), ``tensor`` is returned.
    """
    # Number of length-one axes to add; clipped at zero so a tensor that
    # already has rank >= target_rank is returned unchanged.
    num_dims = tf.maximum(target_rank - tf.rank(tensor), 0)
    output = insert_dims(tensor, num_dims, axis)

    return output

def flatten_dims(tensor, num_dims, axis):
    """
    Flattens a specified set of dimensions of a tensor.

    This operation flattens ``num_dims`` dimensions of a ``tensor``
    starting at a given ``axis``.

    Args:
        tensor : A tensor.
        num_dims (int): The number of dimensions
            to combine. Must be greater than or equal to two and less than
            or equal to the rank of ``tensor``.
        axis (int): The index of the dimension from which to start.

    Returns:
        A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
        dimensions, but the same number of elements.
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)

    msg = "`num_dims` must be <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)

    msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    tf.debugging.assert_greater_equal(axis, 0, msg)

    msg = "`num_dims`+`axis` <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims + axis, tf.rank(tensor), msg)

    if num_dims==len(tensor.shape):
        # All dimensions are flattened into one
        new_shape = [-1]
    elif axis==0:
        # Leading dimensions are flattened; let reshape infer their product
        shape = tf.shape(tensor)
        new_shape = tf.concat([[-1], shape[axis+num_dims:]], 0)
    else:
        # Inner dimensions are flattened; their product is computed from the
        # static shape so the result keeps as much shape information as
        # possible
        shape = tf.shape(tensor)
        flat_dim = tf.reduce_prod(tensor.shape[axis:axis+num_dims])
        new_shape = tf.concat([shape[:axis],
                               [flat_dim],
                               shape[axis+num_dims:]], 0)

    return tf.reshape(tensor, new_shape)

def flatten_last_dims(tensor, num_dims=2):
    """
    Flattens the last `n` dimensions of a tensor.

    This operation flattens the last ``num_dims`` dimensions of a ``tensor``.
    It is a simplified version of the function ``flatten_dims``.

    Args:
        tensor : A tensor.
        num_dims (int): The number of dimensions
            to combine. Must be greater than or equal to two and less than
            or equal to the rank of ``tensor``.

    Returns:
        A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
        dimensions, but the same number of elements.
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)

    msg = "`num_dims` must be <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)

    if num_dims==len(tensor.shape):
        # All dimensions are flattened into one
        new_shape = [-1]
    else:
        # Product of the trailing dimensions is computed from the static
        # shape; the leading dimensions come from the dynamic shape
        shape = tf.shape(tensor)
        last_dim = tf.reduce_prod(tensor.shape[-num_dims:])
        new_shape = tf.concat([shape[:-num_dims], [last_dim]], 0)

    return tf.reshape(tensor, new_shape)

def insert_dims(tensor, num_dims, axis=-1):
    """Adds multiple length-one dimensions to a tensor.

    This operation is an extension to TensorFlow`s ``expand_dims`` function.
    It inserts ``num_dims`` dimensions of length one starting from the
    dimension ``axis`` of a ``tensor``. The dimension
    index follows Python indexing rules, i.e., zero-based, where a negative
    index is counted backward from the end.

    Args:
        tensor : A tensor.
        num_dims (int) : The number of dimensions to add.
        axis : The dimension index at which to expand the
            shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
            ``axis`` must be within the range `[-(D+1), D]` (inclusive).

    Returns:
        A tensor with the same data as ``tensor``, with ``num_dims`` additional
        dimensions inserted at the index specified by ``axis``.
    """
    msg = "`num_dims` must be nonnegative."
    tf.debugging.assert_greater_equal(num_dims, 0, msg)

    rank = tf.rank(tensor)
    msg = "`axis` is out of range `[-(D+1), D]`)"
    tf.debugging.assert_less_equal(axis, rank, msg)
    tf.debugging.assert_greater_equal(axis, -(rank+1), msg)

    # Convert a negative axis to its nonnegative equivalent
    axis = axis if axis>=0 else rank+axis+1
    shape = tf.shape(tensor)
    new_shape = tf.concat([shape[:axis],
                           tf.ones([num_dims], tf.int32),
                           shape[axis:]], 0)
    output = tf.reshape(tensor, new_shape)

    return output

def split_dim(tensor, shape, axis):
    """Reshapes a dimension of a tensor into multiple dimensions.

    This operation splits the dimension ``axis`` of a ``tensor`` into
    multiple dimensions according to ``shape``.

    Args:
        tensor : A tensor.
        shape (list or TensorShape): The shape to which the dimension should
            be reshaped.
        axis (int): The index of the axis to be reshaped.

    Returns:
        A tensor of the same type as ``tensor`` with len(``shape``)-1
        additional dimensions, but the same number of elements.
    """
    msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    tf.debugging.assert_greater_equal(axis, 0, msg)

    s = tf.shape(tensor)
    new_shape = tf.concat([s[:axis], shape, s[axis+1:]], 0)
    output = tf.reshape(tensor, new_shape)

    return output

def matrix_sqrt(tensor):
    r""" Computes the square root of a matrix.

    Given a batch of Hermitian positive semi-definite matrices
    :math:`\mathbf{A}`, returns matrices :math:`\mathbf{B}`,
    such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor`` containing
        the matrix square root of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    if sn.config.xla_compat and not tf.executing_eagerly():
        # XLA does not support tf.linalg.sqrtm; use an eigendecomposition
        # instead: A = U diag(s) U^H  =>  sqrt(A) = U diag(sqrt(s)) U^H
        s, u = tf.linalg.eigh(tensor)

        # Compute sqrt of eigenvalues
        s = tf.abs(s)
        s = tf.sqrt(s)
        s = tf.cast(s, u.dtype)

        # Matrix multiplication
        s = tf.expand_dims(s, -2)
        return tf.matmul(u*s, u, adjoint_b=True)
    else:
        return tf.linalg.sqrtm(tensor)

def matrix_sqrt_inv(tensor):
    r""" Computes the inverse square root of a Hermitian matrix.

    Given a batch of Hermitian positive definite matrices
    :math:`\mathbf{A}`, with square root matrices :math:`\mathbf{B}`,
    such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`, the function
    returns :math:`\mathbf{B}^{-1}`, such that
    :math:`\mathbf{B}^{-1}\mathbf{B}=\mathbf{I}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor`` containing
        the inverse matrix square root of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    if sn.config.xla_compat and not tf.executing_eagerly():
        # XLA-compatible path via eigendecomposition:
        # A = U diag(s) U^H  =>  A^(-1/2) = U diag(1/sqrt(s)) U^H
        s, u = tf.linalg.eigh(tensor)

        # Compute 1/sqrt of eigenvalues
        s = tf.abs(s)
        tf.debugging.assert_positive(s, "Input must be positive definite.")
        s = 1/tf.sqrt(s)
        s = tf.cast(s, u.dtype)

        # Matrix multiplication
        s = tf.expand_dims(s, -2)
        return tf.matmul(u*s, u, adjoint_b=True)
    else:
        return tf.linalg.inv(tf.linalg.sqrtm(tensor))

def matrix_inv(tensor):
    r""" Computes the inverse of a Hermitian matrix.

    Given a batch of Hermitian positive definite matrices
    :math:`\mathbf{A}`, the function
    returns :math:`\mathbf{A}^{-1}`, such that
    :math:`\mathbf{A}^{-1}\mathbf{A}=\mathbf{I}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor``, containing
        the inverse of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    if tensor.dtype in [tf.complex64, tf.complex128] \
            and sn.config.xla_compat \
            and not tf.executing_eagerly():
        # XLA-compatible path for complex Hermitian inputs via
        # eigendecomposition: A = U diag(s) U^H  =>  A^-1 = U diag(1/s) U^H
        s, u = tf.linalg.eigh(tensor)

        # Compute inverse of eigenvalues
        s = tf.abs(s)
        tf.debugging.assert_positive(s, "Input must be positive definite.")
        s = 1/s
        s = tf.cast(s, u.dtype)

        # Matrix multiplication
        s = tf.expand_dims(s, -2)
        return tf.matmul(u*s, u, adjoint_b=True)
    else:
        return tf.linalg.inv(tensor)

def matrix_pinv(tensor):
    r""" Computes the Moore–Penrose (or pseudo) inverse of a matrix.

    Given a batch of :math:`M \times K` matrices :math:`\mathbf{A}` with rank
    :math:`K` (i.e., linearly independent columns), the function returns
    :math:`\mathbf{A}^+`, such that
    :math:`\mathbf{A}^{+}\mathbf{A}=\mathbf{I}_K`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, K]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of shape ([..., K, M]) of the same type as ``tensor``,
        containing the pseudo inverse of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    # Left pseudo inverse for full-column-rank A: A^+ = (A^H A)^-1 A^H
    inv = matrix_inv(tf.matmul(tensor, tensor, adjoint_a=True))
    return tf.matmul(inv, tensor, adjoint_b=True)