mmserv

Minimum Mean Square Error detection on the RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

ccholesky.c (7462B)


      1 #include "../include/common.h"
      2 
      3 #include <stddef.h>
      4 
      5 /** Complex Cholesky decomposition of a Hermitian positive-definite matrix G
      6  *
      7  * LL (floating point solution): 
      8  * G = L*L^H
      9  * L_ij = (G_ij - \sum_{k=0}^{j-1} L_ik L_jk^*) / L_jj
     10  * L_ii = sqrt(G_ii - \sum_{k=0}^{i-1} L_ik L_ik^*)
     11  * 
     12  * LDL (fixed point solution):
     13  * G = L*D*L^H
     14  * L_ij = (G_ij - \sum_{k=0}^{j-1} L_ik D_k L_jk^*) / D_j
     15  * D_i = G_ii - \sum_{k=0}^{i-1} L_ik D_k L_ik^*
     16  *
     17  * \global g_G matrix. Shape [NUM_TX][NUM_TX][NUM_SC]
     18  * \global g_L output lower triangular matrix. Shape [NUM_TX][NUM_TX][NUM_SC]
     19  * \global g_D output diagonal matrix (if DATA_TYPE_fixed is defined). Shape [NUM_TX][NUM_SC]
     20  */
     21 void ccholesky()
     22 {
     23   size_t i, j, k, s;
     24   size_t off_ij, off_jj, off_ii;
     25   size_t off_ik, off_jk;
     26   size_t off_i, off_j, off_k;
     27 
     28 #if defined(ARCH_x86) || defined(ARCH_rv)
     29   data_t tmp; /* Temporary variable for sqrt */
     30   acc_t sum_re, sum_im;
     31 
     32   for (i = 0; i < NUM_TX; ++i) {
     33     for (j = 0; j <= i; ++j) {
     34       for (s = 0; s < NUM_SC; ++s) {
     35         off_ij = i * NUM_TX * NUM_SC + j * NUM_SC + s;
     36         sum_im = sum_re = 0;
     37 
     38         /* Calculate the sum */
     39         for (k = 0; k < j; ++k) {
     40           off_ik = i * NUM_TX * NUM_SC + k * NUM_SC + s;
     41           off_jk = j * NUM_TX * NUM_SC + k * NUM_SC + s;
     42 #if defined(DATA_TYPE_float)
     43           sum_re += g_L.re[off_ik] * g_L.re[off_jk]
     44                   - g_L.im[off_ik] * g_L.im[off_jk];
     45 #elif defined(DATA_TYPE_fixed)
     46           sum_re += (g_L.re[off_ik] * g_L.re[off_jk]
     47                   - g_L.im[off_ik] * g_L.im[off_jk])
     48                   * g_D[k * NUM_SC + s];
     49 #else
     50 #error "Unknown data type"
     51 #endif
     52           sum_im += g_L.re[off_ik] * g_L.im[off_jk]
     53                   + g_L.im[off_ik] * g_L.re[off_jk];
     54         }
     55 
     56         if (i == j) {
     57           off_ii = i * NUM_TX * NUM_SC + i * NUM_SC + s;
     58 #if defined(DATA_TYPE_float)
     59 
     60 #if defined(ARCH_x86)
     61           __asm__ volatile (
     62             "flds %1\n"
     63             "fsubs %2\n"
     64             "fsqrt\n"
     65             "fstps %0\n"
     66             : "=m" (g_L.re[off_ii])
     67             : "m" (g_G.re[off_ij]), "m" (sum_re)
     68           );
     69 #elif defined(ARCH_rv)
     70           __asm__ volatile (
     71             "fsub.s %0, %1, %2\n"   /* tmp = g_G.re[off_ij] - sum_re */
     72             "fsqrt.s %0, %0\n"      /* tmp = sqrtf(tmp) */
     73             : "=&f"(tmp) : "f"(g_G.re[off_ij]), "f"(sum_re)
     74           );
     75           g_L.re[off_ii] = tmp;
     76           g_L.im[off_ii] = 0;
     77 #else
     78 #error "Unknown architecture"
     79 #endif
     80 
     81 #elif defined(DATA_TYPE_fixed)
     82           /* Calculate D_i = G_ii - sum */
     83           g_D[i * NUM_SC + s] = g_G.re[off_ii] - (data_t)(sum_re >> FP_Q);
     84 #else
     85 #error "Unknown data type"
     86 #endif
     87 
     88         } else { /* i != j */
     89 #if defined(DATA_TYPE_float)
     90           /* Calculate L_ij = (G_ij - sum) / L_jj */
     91           off_jj = j * NUM_TX * NUM_SC + j * NUM_SC + s;
     92           g_L.re[off_ij] = (g_G.re[off_ij] - sum_re) / g_L.re[off_jj];
     93           g_L.im[off_ij] = (g_G.im[off_ij] - sum_im) / g_L.re[off_jj];
     94 #elif defined(DATA_TYPE_fixed)
     95           /* Calculate L_ij = (G_ij - sum) / D_j */
     96           off_j = j * NUM_SC + s;
     97           sum_re = ((acc_t)g_G.re[off_ij] << FP_Q) - sum_re;
     98           g_L.re[off_ij] = (data_t)(sum_re / (acc_t)g_D[off_j]);
     99           sum_im = ((acc_t)g_G.im[off_ij] << FP_Q) - sum_im;
    100           g_L.im[off_ij] = (data_t)(sum_im / (acc_t)g_D[off_j]);
    101 #else
    102 #error "Unknown data type"
    103 #endif
    104         }
    105       }
    106     }
    107   }
    108 #elif defined(ARCH_rvv)
    109   size_t sz, vl;
    110 
    111   for (i = 0; i < NUM_TX; ++i)
    112     for (j = 0; j <= i; ++j) {
    113       sz = NUM_SC;
    114       s = 0;
    115 
    116       while (sz > 0) {
    117         /* Calculate
    118          * sum_{k=0}^{j-1} L_ik L_jk^* or
    119          * sum_{k=0}^{j-1} L_ik D_k L_jk^*
    120          * 
    121          * v0 - sum real part
    122          * v4 - sum imaginary part */
    123         __asm__ volatile(
    124           "vsetvli %0, %1, e32, m4, ta, ma\n"
    125           "vmv.v.i v0, 0\n"
    126           "vmv.v.i v4, 0\n"
    127           : "=r"(vl) : "r"(sz)
    128         );
    129 
    130         for (k = 0; k < j; ++k) {
    131           off_ik = i * NUM_TX * NUM_SC + k * NUM_SC + s;
    132           off_jk = j * NUM_TX * NUM_SC + k * NUM_SC + s;
    133           __asm__ volatile(
    134             "vle32.v v8, (%0)\n"
    135             "vle32.v v12, (%1)\n"
    136             "vle32.v v16, (%2)\n"
    137             "vle32.v v20, (%3)\n"
    138 #if defined(DATA_TYPE_float)
    139             /* real part */
    140             "vfmacc.vv v0, v8, v16\n"
    141             "vfmacc.vv v0, v12, v20\n"
    142             /* imaginary part */
    143             "vfmacc.vv v4, v12, v16\n"
    144             "vfnmsac.vv v4, v8, v20\n"
    145 #elif defined(DATA_TYPE_fixed)
    146             "vle32.v v24, (%4)\n"
    147             /* real part */
    148             "vsmul.vv v28, v8, v24\n"
    149             "vsmul.vv v28, v28, v16\n"
    150             "vsadd.vv v0, v0, v28\n"
    151             "vsmul.vv v28, v12, v24\n"
    152             "vsmul.vv v28, v28, v20\n"
    153             "vsadd.vv v0, v0, v28\n"
    154             /* imaginary part */
    155             "vsmul.vv v28, v12, v16\n"
    156             "vsmul.vv v28, v28, v20\n"
    157             "vsadd.vv v4, v4, v28\n"
    158             "vsmul.vv v28, v8, v24\n"
    159             "vsmul.vv v28, v28, v20\n"
    160             "vssub.vv v4, v4, v28\n"
    161 #else
    162 #error "Unknown data type"
    163 #endif
    164           :
    165           : "r"(&g_L.re[off_ik]), "r"(&g_L.im[off_ik]),
    166             "r"(&g_L.re[off_jk]), "r"(&g_L.im[off_jk])
    167 #if defined(DATA_TYPE_fixed)
    168             , "r"(&g_D[k * NUM_SC + s])
    169 #endif
    170           );
    171         }
    172 
    173         if (i == j) {
    174           off_ii = i * NUM_TX * NUM_SC + i * NUM_SC + s;
    175           /* L_ii = sqrt(G_ii - sum) or D_i = (G_ii - sum) */
    176           /* G_ii imaginary part is 0, so we can ignore it */
    177           __asm__ volatile(
    178             "vle32.v v8, (%0)\n"
    179 #if defined(DATA_TYPE_float)
    180             "vfsub.vv v0, v8, v0\n"
    181             "vfsqrt.v v0, v0\n"
    182             "vfneg.v v4, v4\n"
    183 #elif defined(DATA_TYPE_fixed)
    184             "vssub.vv v0, v8, v0\n"
    185             "vneg.v v4, v4\n"
    186 #else
    187 #error "Unknown data type"
    188 #endif
    189             "vse32.v v0, (%1)\n"
    190             : : "r"(&g_G.re[off_ii]),
    191 #if defined(DATA_TYPE_float)
    192             "r"(&g_L.re[off_ii])
    193 #elif defined(DATA_TYPE_fixed)
    194             "r"(&g_D[i * NUM_SC + s])
    195 #else
    196 #error "Unknown data type"
    197 #endif
    198           );
    199         } else { /* i != j */
    200           off_ij = i * NUM_TX * NUM_SC + j * NUM_SC + s;
    201           off_jj = j * NUM_TX * NUM_SC + j * NUM_SC + s;
    202           off_j = j * NUM_SC + s;
    203           /* Calculate L_ij = (G_ij - sum) / L_jj or L_ij = (G_ij - sum) / D_j */
    204           /* L_jj is always real, so we can ignore the imaginary part */
    205           __asm__ volatile(
    206             "vle32.v v8, (%0)\n"  /* G_ij.re */
    207             "vle32.v v12, (%1)\n" /* G_ij.im */
    208             "vle32.v v16, (%2)\n" /* L_jj or D_j */
    209 #if defined(DATA_TYPE_float)
    210             "vfsub.vv v0, v8, v0\n"
    211             "vfsub.vv v4, v12, v4\n"
    212             "vfdiv.vv v0, v0, v16\n"
    213             "vfdiv.vv v4, v4, v16\n"
    214 #elif defined(DATA_TYPE_fixed)
    215             "vssub.vv v0, v8, v0\n"
    216             "vssub.vv v4, v12, v4\n"
    217             "vdiv.vv v0, v0, v16\n"
    218             "vdiv.vv v4, v4, v16\n"
    219             "vsll.vi v0, v0, 1\n"
    220             "vsll.vi v4, v4, 1\n"
    221 #else
    222 #error "Unknown data type"
    223 #endif
    224             /* Store result */
    225             "vse32.v v0, (%3)\n"
    226             "vse32.v v4, (%4)\n"
    227             :
    228             : "r"(&g_G.re[off_ij]), "r"(&g_G.im[off_ij]),
    229               "r"(&g_L.re[off_jj]),
    230               "r"(&g_L.re[off_ij]), "r"(&g_L.im[off_ij])
    231           );
    232         }
    233 
    234         sz -= vl;
    235         s += vl;
    236       }
    237     }
    238 #else
    239 #error "Unknown architecture"
    240 #endif
    241 }