mmserv

Minimum Mean Square Error detection on RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

cbackwardsub.c (4542B)


      1 #include "../include/common.h"
      2 
      3 #include <stddef.h>
      4 
      5 /** Complex backward substitution L^H*x_MMSE = z
      6  * 
      7  * x_{MMSE}_t = (z_t - \sum_{tt=t+1}^{NUM_TX-1} L_{tt t} x_{MMSE}_{tt}) / L_{t t} (for LL / float solution)
      8  * x_{MMSE}_t = (z_t / D_t - \sum_{tt=t+1}^{NUM_TX-1} L_{tt t} x_{MMSE}_{tt}) (for LDL / fixed solution)
      9  * 
     10  * \global g_L lower triangular matrix. Shape [NUM_TX][NUM_TX][NUM_SC]
     11  * \global g_D diagonal matrix (only if DATA_TYPE_fixed is defined). Shape [NUM_TX][NUM_SC]
     12  * \global g_z rhs vector. Shape [NUM_TX][NUM_SC]
     13  * \global g_x_MMSE output vector. Shape [NUM_TX][NUM_SC]
     14  */
     15 void cbackwardsub()
     16 {
     17   size_t t, tt, s;
     18   size_t off_L, off_z, off_x_MMSE, off_D;
     19 
     20 #if defined(ARCH_x86) || defined(ARCH_rv)
     21   acc_t sum_re, sum_im;
     22 
     23   for (t = NUM_TX - 1; t != (size_t)-1; --t) {
     24     for (s = 0; s < NUM_SC; ++s) {
     25       sum_re = sum_im = 0;
     26       for (tt = t + 1; tt < NUM_TX; ++tt) {
     27         off_L = tt * NUM_TX * NUM_SC + t * NUM_SC + s;
     28         off_x_MMSE = tt * NUM_SC + s;
     29         sum_re += (acc_t)g_L.re[off_L] * (acc_t)g_x_MMSE.re[off_x_MMSE]
     30                 - (acc_t)g_L.im[off_L] * (acc_t)g_x_MMSE.im[off_x_MMSE];
     31         sum_im += (acc_t)g_L.re[off_L] * (acc_t)g_x_MMSE.im[off_x_MMSE]
     32                 + (acc_t)g_L.im[off_L] * (acc_t)g_x_MMSE.re[off_x_MMSE];
     33       }
     34       off_z = t * NUM_SC + s;
     35       off_x_MMSE = t * NUM_SC + s;
     36 #if defined(DATA_TYPE_float)
     37       off_L = t * NUM_TX * NUM_SC + t * NUM_SC + s;
     38       g_x_MMSE.re[off_x_MMSE] = (g_z.re[off_z] - (data_t)sum_re) / g_L.re[off_L];
     39       g_x_MMSE.im[off_x_MMSE] = (g_z.im[off_z] - (data_t)sum_im) / g_L.re[off_L];
     40 #elif defined(DATA_TYPE_fixed)
     41       off_D = t * NUM_SC + s;
     42       g_x_MMSE.re[off_x_MMSE] = (data_t)((acc_t)(g_z.re[off_z] << FP_Q) / g_D[off_D])
     43                               - (data_t)(sum_re >> FP_Q);
     44       g_x_MMSE.im[off_x_MMSE] = (data_t)((acc_t)(g_z.im[off_z] << FP_Q) / g_D[off_D])
     45                               - (data_t)(sum_im >> FP_Q);
     46 #else
     47 #error "Unknown data type"
     48 #endif
     49     }
     50   }
     51 
     52 #elif defined(ARCH_rvv)
     53   size_t sz, vl;
     54 
     55   for (t = NUM_TX - 1; t != (size_t)-1; --t) {
     56     sz = NUM_SC;
     57     s = 0;
     58 
     59     while (sz > 0){
     60       /* Initialize x_MMSE as z */
     61       /* v0 - x_MMSE real part */
     62       /* v4 - x_MMSE imaginary part */
     63       off_z = t * NUM_SC + s;
     64       __asm__ volatile(
     65         "vsetvli %0, %1, e32, m4, ta, ma\n"
     66         "vle32.v v0, (%2)\n"
     67         "vle32.v v4, (%3)\n"
     68         : "=r"(vl)
     69         : "r"(sz),
     70           "r"(&g_z.re[off_z]), "r"(&g_z.im[off_z])
     71       );
     72 
     73 #if defined(DATA_TYPE_fixed)
     74       /* Divide by D_t */
     75       __asm__ volatile (
     76         "vle32.v v8, (%0)\n"
     77         "vfdiv.vv v0, v0, v8\n"
     78         "vfdiv.vv v4, v4, v8\n"
     79         :
     80         : "r"(&g_D[t * NUM_SC + s])
     81       );
     82 #endif
     83 
     84       for (tt = tt + 1; tt < NUM_TX; ++tt) {
     85         /* z - sum L_{tt t} * x_{MMSE}_{tt} or
     86          * z / D_t - sum L_{tt t} * x_{MMSE}_{tt} */
     87         off_L = tt * NUM_TX * NUM_SC + t * NUM_SC + s;
     88         off_x_MMSE = tt * NUM_SC + s;
     89         __asm__ volatile(
     90           "vle32.v v8, (%0)\n"
     91           "vle32.v v12, (%1)\n"
     92           "vle32.v v16, (%2)\n"
     93           "vle32.v v20, (%3)\n"
     94 #if defined(DATA_TYPE_float)
     95           /* real part */
     96           "vfnmsac.vv v0, v8, v16\n"
     97           "vfmacc.vv v0, v12, v20\n"
     98           /* imaginary part */
     99           "vfnmsac.vv v4, v12, v16\n"
    100           "vfnmsac.vv v4, v8, v20\n"
    101 #elif defined(DATA_TYPE_fixed)
    102           /* real part */
    103           "vsmul.vv v24, v8, v16\n"
    104           "vssub.vv v0, v0, v24\n"
    105           "vsmul.vv v24, v12, v20\n"
    106           "vsadd.vv v0, v0, v24\n"
    107           /* imaginary part */
    108           "vsmul.vv v24, v12, v16\n"
    109           "vssub.vv v4, v4, v24\n"
    110           "vsmul.vv v24, v8, v20\n"
    111           "vssub.vv v4, v4, v24\n"
    112 #else
    113 #error "Unknown data type"
    114 #endif
    115           :
    116           : "r"(&g_L.re[off_L]), "r"(&g_L.im[off_L]),
    117             "r"(&g_x_MMSE.re[off_x_MMSE]), "r"(&g_x_MMSE.im[off_x_MMSE])
    118         );
    119       }
    120 
    121 #if defined(DATA_TYPE_float)
    122       /* Divide by L_ii */
    123       off_L = t * NUM_TX * NUM_SC + t * NUM_SC + s;
    124       __asm__ volatile (
    125         "vle32.v v8, (%0)\n"
    126         "vfdiv.vv v0, v0, v8\n"
    127         "vfdiv.vv v4, v4, v8\n"
    128         :
    129         : "r"(&g_L.re[off_L])
    130       );
    131 #endif
    132 
    133       /* Store result */
    134       off_x_MMSE = t * NUM_SC + s;
    135       __asm__ volatile(
    136         "vse32.v v0, (%0)\n"
    137         "vse32.v v4, (%1)\n"
    138         :
    139         : "r"(&g_x_MMSE.re[off_x_MMSE]), "r"(&g_x_MMSE.im[off_x_MMSE])
    140       );
    141 
    142       sz -= vl;
    143       s += vl;
    144     }
    145   }
    146 
    147 #else
    148 #error "Unknown architecture"
    149 #endif
    150 }