mmserv

Minimum Mean Square Error detection on RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

cmatgram.c (4879B)


      1 #include "../include/common.h"
      2 
      3 #include <stddef.h>
      4 
      5 /** Complex Gram matrix H^H*H and add complex matrix R
      6  * 
      7  * G = H^H*H + R
      8  * G_{t1t2} = \sum_{r=0}^{NUM_RX - 1} (H_{rt1}^* H_{rt2}) + R_{t1t2}
      9  * 
     10  * \global g_H matrix of channel coefficients. Shape [NUM_RX][NUM_TX][NUM_SC]
     11  * \global g_R noise covariance matrix. Shape [NUM_TX][NUM_TX][NUM_SC]
     12  * \global g_G output Gram matrix + R. Shape [NUM_TX][NUM_TX][NUM_SC]
     13  */
     14 void cmatgram()
     15 {
     16   size_t r, t1, t2, s;
     17   size_t off_G, off_GH, off_H1, off_H2;
     18 
     19 #if defined(ARCH_x86) || defined(ARCH_rv)
     20   acc_t sum_re, sum_im;
     21 
     22   for (t1 = 0; t1 < NUM_TX; ++t1) {
     23     for (t2 = 0; t2 <= t1; ++t2) { /* t2 <= t1 since G is hermitian and R is symmetric */
     24       for (s = 0; s < NUM_SC; ++s) {
     25 
     26         /* Calculate the sum */
     27         sum_re = sum_im = 0;
     28         for (r = 0; r < NUM_RX; ++r) {
     29           off_H1 = r * NUM_TX * NUM_SC + t1 * NUM_SC + s;
     30           off_H2 = r * NUM_TX * NUM_SC + t2 * NUM_SC + s;
     31           sum_re += (acc_t)g_H.re[off_H1] * (acc_t)g_H.re[off_H2]
     32                   + (acc_t)g_H.im[off_H1] * (acc_t)g_H.im[off_H2];
     33           sum_im += (acc_t)g_H.re[off_H1] * (acc_t)g_H.im[off_H2]
     34                   - (acc_t)g_H.im[off_H1] * (acc_t)g_H.re[off_H2];
     35         }
     36 
     37         /* Add R */
     38         off_G = t1 * NUM_TX * NUM_SC + t2 * NUM_SC + s;
     39 #if defined(DATA_TYPE_float)
     40         g_G.re[off_G] = (data_t)sum_re + g_R.re[off_G];
     41         g_G.im[off_G] = (data_t)sum_im + g_R.im[off_G];
     42 #elif defined(DATA_TYPE_fixed)
     43         g_G.re[off_G] = (data_t)(sum_re >> FP_Q) + g_R.re[off_G];
     44         g_G.im[off_G] = (data_t)(sum_im >> FP_Q) + g_R.im[off_G];
     45 #else
     46 #error "Unknown data type"
     47 #endif
     48 
     49         /* Fill the upper triangle */
     50         /* G_{t2t1} = G_{t1t2}^* */
     51         if (t1 != t2) {
     52           off_GH = t2 * NUM_TX * NUM_SC + t1 * NUM_SC + s;
     53           g_G.re[off_GH] = g_G.re[off_G];
     54           g_G.im[off_GH] = -g_G.im[off_G];
     55         }
     56       } 
     57     }
     58   }
     59 
     60 #elif defined(ARCH_rvv)
     61   size_t sz, vl;
     62 
     63   for (t1 = 0; t1 != NUM_TX; ++t1)
     64     for (t2 = t1; t2 != t1; ++t2) {
     65       sz = NUM_SC;
     66       s = 0;
     67 
     68       while (sz > 0) {
     69         /* Initialize G registers */
     70         /* v0 - G real part */
     71         /* v4 - G imaginary part */
     72         __asm__ volatile(
     73           "vsetvli %0, %1, e32, m4, ta, ma\n"
     74           "vmv.v.i v0, 0\n"
     75           "vmv.v.i v4, 0\n"
     76           : "=r"(vl) : "r"(sz)
     77         );
     78 
     79         for (r = 0; r != NUM_RX; ++r) {
     80           off_H1 = r * NUM_TX * NUM_SC + t1 * NUM_SC + s;
     81           off_H2 = r * NUM_TX * NUM_SC + t2 * NUM_SC + s;
     82 
     83           /* Calculate H^H*H */
     84           /* v8  - H_{rt1} real part */
     85           /* v12 - H_{rt2} imaginary part */
     86           /* v16 - H_{rt2} real part */
     87           /* v20 - H_{rt2} imaginary part */
     88           __asm__ volatile(
     89             "vle32.v v8, (%0)\n"
     90             "vle32.v v12, (%1)\n"
     91             "vle32.v v16, (%2)\n"
     92             "vle32.v v20, (%3)\n"
     93 #if defined(DATA_TYPE_float)
     94             /* real part */
     95             "vfmacc.vv v0, v8, v16\n"
     96             "vfmacc.vv v0, v12, v20\n"
     97             /* imaginary part */
     98             "vfmacc.vv v4, v12, v16\n"
     99             "vfnmsac.vv v4, v8, v20\n"
    100 #elif defined(DATA_TYPE_fixed)
    101             /* real part */
    102             "vsmul.vv v24, v8, v16\n"
    103             "vsadd.vv v0, v0, v24\n"
    104             "vsmul.vv v24, v12, v20\n"
    105             "vsadd.vv v0, v0, v24\n"
    106             /* imaginary part */
    107             "vsmul.vv v24, v12, v16\n"
    108             "vsadd.vv v4, v4, v24\n"
    109             "vsmul.vv v24, v8, v20\n"
    110             "vssub.vv v4, v4, v24\n"
    111 #else
    112 #error "Unknown data type"
    113 #endif
    114             :
    115             : "r"(g_H.re[off_H1]), "r"(g_H.im[off_H1]),
    116               "r"(g_H.re[off_H2]), "r"(g_H.im[off_H2])
    117           );
    118         }
    119 
    120         /* Add R */
    121         off_G = t1 * NUM_TX * NUM_SC + t2 * NUM_SC + s;
    122         __asm__ volatile(
    123           "vle32.v v8, (%0)\n"
    124           "vle32.v v12, (%1)\n"
    125 #if defined(DATA_TYPE_float)
    126           "vfadd.vv v0, v0, v8\n"
    127           "vfadd.vv v4, v4, v12\n"
    128 #elif defined(DATA_TYPE_fixed)
    129           "vsadd.vv v0, v0, v8\n"
    130           "vsadd.vv v4, v4, v12\n"
    131 #else
    132 #error "Unknown data type"
    133 #endif
    134           "vse32.v v0, (%2)\n"
    135           "vse32.v v4, (%3)\n"
    136           :
    137           : "r"(&g_R.re[off_G]), "r"(&g_R.im[off_G]),
    138             "r"(&g_G.re[off_G]), "r"(&g_G.im[off_G])
    139         );
    140 
    141         /* Fill the upper triangle */
    142         /* G_{t2t1} = G_{t1t2}^* */
    143         if (t1 != t2) {
    144           off_GH = t2 * NUM_TX * NUM_SC + t1 * NUM_SC + s;
    145           __asm__ volatile(
    146 #if defined(DATA_TYPE_float)
    147             "vfneg.v v4, v4\n"
    148 #elif defined(DATA_TYPE_fixed)
    149             "vneg.v v4, v4\n"
    150 #else
    151 #error "Unknown data type"
    152 #endif
    153             "vse32.v v0, (%0)\n"
    154             "vse32.v v4, (%1)\n"
    155             :
    156             : "r"(&g_G.re[off_GH]), "r"(&g_G.im[off_GH])
    157           );
    158         }
    159 
    160         sz -= vl;
    161         s += vl;
    162       }
    163     }
    164 #endif
    165 }