mmserv

Minimum Mean Square Error detection on RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

cmatvecmul.c (3022B)


      1 #include "../include/common.h"
      2 
      3 #include <stddef.h>
      4 
      5 /** Complex matrix-vector multiplication HHy = HH*y
      6  * 
      7  * HHy_t = \sum_{r=0}^{NUM_RX-1} H_{rt}^* * y_r
      8  *
      9  * \global g_H matrix. Shape [NUM_RX][NUM_TX][NUM_SC]
     10  * \global g_y vector. Shape [NUM_RX][NUM_SC]
     11  * \global g_HHy output vector. Shape [NUM_TX][NUM_SC]
     12  */
     13 void cmatvecmul()
     14 {
     15   size_t t, r, s;
     16   size_t off_H, off_y, off_HHy;
     17 
     18 #if defined(ARCH_x86) || defined(ARCH_rv)
     19   acc_t sum_re, sum_im;
     20 
     21   for (t = 0; t < NUM_TX; ++t) {
     22     for (s = 0; s < NUM_SC; ++s) {
     23       sum_re = sum_im = 0;
     24       for (r = 0; r < NUM_RX; ++r) {
     25         off_H = r * NUM_TX * NUM_SC + t * NUM_SC + s;
     26         off_y = r * NUM_SC + s;
     27         sum_re += (acc_t)g_H.re[off_H] * (acc_t)g_y.re[off_y]
     28                 - (acc_t)g_H.im[off_H] * (acc_t)g_y.im[off_y];
     29         sum_im += (acc_t)g_H.re[off_H] * (acc_t)g_y.im[off_y]
     30                 + (acc_t)g_H.im[off_H] * (acc_t)g_y.re[off_y];
     31       }
     32       off_HHy = t * NUM_SC + s;
     33 #if defined(DATA_TYPE_float)
     34       g_HHy.re[off_HHy] = (data_t)sum_re;
     35       g_HHy.im[off_HHy] = (data_t)sum_im;
     36 #elif defined(DATA_TYPE_fixed)
     37       g_HHy.re[off_HHy] = (data_t)(sum_re >> FP_Q);
     38       g_HHy.im[off_HHy] = (data_t)(sum_im >> FP_Q);
     39 #else
     40 #error "Unknown data type"
     41 #endif
     42     }
     43   }
     44 
     45 #elif defined(ARCH_rvv)
     46   size_t sz, vl;
     47 
     48   for (t = 0; t < NUM_TX; ++t) {
     49     sz = NUM_SC;
     50     s = 0;
     51     while (sz > 0) {
     52       /* Initialize HHy with 0 */
     53       /* v0 - HHy real part */
     54       /* v4 - HHy imaginary part */
     55       __asm__ volatile(
     56         "vsetvli %0, %1, e32, m4, ta, ma\n"
     57         "vmv.v.i v0, 0\n"
     58         "vmv.v.i v4, 0\n"
     59         : "=r"(vl)
     60         : "r"(sz)
     61       );
     62 
     63       for (r = 0; r < NUM_RX; ++r) {
     64         off_H = r * NUM_TX * NUM_SC + t * NUM_SC + s;
     65         off_y = r * NUM_SC + s;
     66         __asm__ volatile(
     67           "vle32.v v8, (%0)\n"
     68           "vle32.v v12, (%1)\n"
     69           "vle32.v v16, (%2)\n"
     70           "vle32.v v20, (%3)\n"
     71 #if defined(DATA_TYPE_float)
     72           /* real part */
     73           "vfmacc.vv v0, v8, v16\n"
     74           "vfnmsac.vv v0, v12, v20\n"
     75           /* imaginary part */
     76           "vfmacc.vv v4, v12, v16\n"
     77           "vfmacc.vv v4, v8, v20\n"
     78 #elif defined(DATA_TYPE_fixed)
     79           /* real part */
     80           "vsmul.vv v24, v8, v16\n"
     81           "vsadd.vv v0, v0, v24\n"
     82           "vsmul.vv v24, v12, v20\n"
     83           "vssub.vv v0, v0, v24\n"
     84           /* imaginary part */
     85           "vsmul.vv v24, v12, v16\n"
     86           "vsadd.vv v4, v4, v24\n"
     87           "vsmul.vv v24, v8, v20\n"
     88           "vsadd.vv v4, v4, v24\n"
     89 #else
     90 #error "Unknown data type"
     91 #endif
     92           :
     93           : "r"(&g_H.re[off_H]), "r"(&g_H.im[off_H]),
     94             "r"(&g_y.re[off_y]), "r"(&g_y.im[off_y])
     95         );
     96       }
     97 
     98       /* Store result */
     99       off_HHy = t * NUM_SC + s;
    100       __asm__ volatile(
    101         "vse32.v v0, (%0)\n"
    102         "vse32.v v4, (%1)\n"
    103         :
    104         : "r"(&g_HHy.re[off_HHy]), "r"(&g_HHy.im[off_HHy])
    105       );
    106 
    107       sz -= vl;
    108       s += vl;
    109     }
    110   }
    111 #else
    112 #error "Unknown architecture"
    113 #endif
    114 }