mmserv

Minimum Mean Square Error detection on RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

commit 43fb5cc3e9d9bcb9f778aafa8faa9b86b360b735
parent 8b6b928b2efe23947a62358bf75d34f31de876ab
Author: Egor Achkasov <eaachkasov@edu.hse.ru>
Date:   Fri, 24 Jan 2025 13:01:31 +0100

Replace matmul with gram matrix

Diffstat:
M src/mmserv.c | 228 +++++++++++++++++++++++++++++++++++++++++--------------------------------------
1 file changed, 118 insertions(+), 110 deletions(-)

diff --git a/src/mmserv.c b/src/mmserv.c @@ -5,6 +5,10 @@ #include <stddef.h> /* for size_t */ +/* + * Debug + */ + #ifdef DEBUG #include "../../common/runtime.h" #include "../../common/util.h" @@ -84,23 +88,6 @@ void csqrt( * Complex matrix operations */ -void cmat_hermitian_transpose_RxTx( - IN vcomplex *A, - OUT vcomplex *AH) -{ - size_t i, j, k, off_ijk = 0, off_jik; - for (i = 0; i < NUM_RX_ANT; ++i) - for (j = 0; j < NUM_TX_ANT; ++j) { - off_jik = j * NUM_RX_ANT * NUM_SC + i * NUM_SC; - for (k = 0; k < NUM_SC; ++k) { - AH->re[off_jik] = A->re[off_ijk]; - AH->im[off_jik] = -A->im[off_ijk]; - ++off_ijk; - ++off_jik; - } - } -} - void cmat_hermitian_transpose_TxTx( IN vcomplex *A, OUT vcomplex *AH) @@ -118,98 +105,127 @@ void cmat_hermitian_transpose_TxTx( } } -void cmatmul_TxRx_RxTx( +/** Complex Gram matrix A^H*A and add complex matrix R (A^H*A + R) + * \param A matrix of channel coefficients. Shape [NUM_RX_ANT][NUM_TX_ANT][NUM_SC] + * \param R noise covariance matrix. Shape [NUM_TX_ANT][NUM_TX_ANT][NUM_SC] + * \param result output Gram matrix. 
Shape [NUM_TX_ANT][NUM_TX_ANT][NUM_SC] + */ +void cmatgram_TxRx_cadd( IN vcomplex *A, - IN vcomplex *B, + IN vcomplex *R, OUT vcomplex *result) { - size_t i, j, k, l; - size_t off_ijk = 0, off_ilk, off_ljk; - data_t A_re, A_im, B_re, B_im; - - for (i = 0; i < NUM_TX_ANT * NUM_TX_ANT * NUM_SC; ++i) - result->re[i] = result->im[i] = 0.f; - - for (i = 0; i != NUM_TX_ANT; ++i) { - for (j = 0; j != NUM_RX_ANT; ++j) { - for (k = 0; k != NUM_SC; ++k) { - off_ilk = i * NUM_RX_ANT * NUM_SC + k; - off_ljk = j * NUM_SC + k; - for (l = 0; l != NUM_RX_ANT; ++l) { - A_re = A->re[off_ilk]; - A_im = A->im[off_ilk]; - B_re = B->re[off_ljk]; - B_im = B->im[off_ljk]; - result->re[off_ijk] += A_re * B_re - A_im * B_im; - result->im[off_ijk] += A_re * B_im + A_im * B_re; - off_ilk += NUM_SC; - off_ljk += NUM_TX_ANT * NUM_SC; - } - ++off_ijk; + size_t t1, t2, r; + size_t sz, vl; + size_t off_sc, off_A, off_AH; + size_t off_result_L, off_result_U; + vfloat32m1_t vA_re, vA_im, vAH_re, vAH_im; + vfloat32m1_t vR; + vfloat32m1_t vresult_re, vresult_im; + vfloat32m1_t vt; + + printf("A:\n["); + for (size_t i = 0; i < NUM_TX_ANT; ++i) { + printf("["); + for (size_t j = 0; j < NUM_RX_ANT; ++j) { + for (size_t k = 0; k < NUM_SC; ++k) { + printf("\t%f %+fj,", A->re[i * NUM_RX_ANT * NUM_SC + j * NUM_SC + k], A->im[i * NUM_RX_ANT * NUM_SC + j * NUM_SC + k]); } } + printf("],\n"); } -} - -void cmatmul_TxTx_TxTx( - IN vcomplex *A, - IN vcomplex *B, - OUT vcomplex *result) -{ - size_t i, j, k, l; - size_t off_ijk = 0, off_ilk, off_ljk; - data_t A_re, A_im, B_re, B_im; - - for (i = 0; i < NUM_TX_ANT * NUM_TX_ANT * NUM_SC; ++i) - result->re[i] = result->im[i] = 0.f; + printf("]\n"); + + printf("R:\n["); + for (size_t i = 0; i < NUM_TX_ANT; ++i) { + printf("["); + for (size_t j = 0; j < NUM_TX_ANT; ++j) { + for (size_t k = 0; k < NUM_SC; ++k) { + printf("\t%f %fj,", R->re[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k], R->im[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k]); + } + } + printf("],\n"); + } + 
printf("]\n"); + + for (t1 = 0; t1 != NUM_TX_ANT; ++t1) + for (t2 = 0; t2 != NUM_TX_ANT; ++t2) + for (r = 0; r != NUM_SC; ++r) + result->re[t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + r] + = result->im[t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + r] = 0.f; + + for (t1 = 0; t1 != NUM_TX_ANT; ++t1) + for (t2 = t1; t2 != NUM_TX_ANT; ++t2) { + off_sc = 0; + off_result_L = t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC; + off_result_U = t2 * NUM_TX_ANT * NUM_SC + t1 * NUM_SC; + sz = NUM_SC; + while (sz > 0) { + vl = vsetvl_e32m1(sz); + vresult_re = vfmv_v_f_f32m1(0.f, vl); + vresult_im = vfmv_v_f_f32m1(0.f, vl); + + for (r = 0; r != NUM_RX_ANT; ++r) { + off_A = r * NUM_TX_ANT * NUM_SC + t1 * NUM_SC + off_sc; + off_AH = r * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + off_sc; + vA_re = vle32_v_f32m1(&A->re[off_A], vl); + vA_im = vle32_v_f32m1(&A->im[off_A], vl); + vAH_re = vle32_v_f32m1(&A->re[off_AH], vl); + vAH_im = vle32_v_f32m1(&A->im[off_AH], vl); + + /* real part */ + vt = vfmul_vv_f32m1(vA_re, vAH_re, vl); + vresult_re = vfadd_vv_f32m1(vresult_re, vt, vl); + vt = vfmul_vv_f32m1(vA_im, vAH_im, vl); + vresult_re = vfadd_vv_f32m1(vresult_re, vt, vl); + + /* imaginary part */ + vt = vfmul_vv_f32m1(vA_im, vAH_re, vl); + vresult_im = vfadd_vv_f32m1(vresult_im, vt, vl); + vt = vfmul_vv_f32m1(vA_re, vAH_im, vl); + vresult_im = vfsub_vv_f32m1(vresult_im, vt, vl); + } - for (i = 0; i != NUM_TX_ANT; ++i) { - for (j = 0; j != NUM_TX_ANT; ++j) { - for (k = 0; k != NUM_SC; ++k) { - off_ilk = i * NUM_TX_ANT * NUM_SC + k; - off_ljk = j * NUM_SC + k; - for (l = 0; l != NUM_TX_ANT; ++l) { - A_re = A->re[off_ilk]; - A_im = A->im[off_ilk]; - B_re = B->re[off_ljk]; - B_im = B->im[off_ljk]; - result->re[off_ijk] += A_re * B_re - A_im * B_im; - result->im[off_ijk] += A_re * B_im + A_im * B_re; - off_ilk += NUM_SC; - off_ljk += NUM_TX_ANT * NUM_SC; + /* Upper triangle */ + /* Add R */ + vR = vle32_v_f32m1(&R->re[off_result_U], vl); + vresult_re = vfadd_vv_f32m1(vresult_re, vR, vl); + vR = 
vle32_v_f32m1(&R->im[off_result_U], vl); + vresult_im = vfadd_vv_f32m1(vresult_im, vR, vl); + /* Store result */ + vse32_v_f32m1(&result->re[off_result_U], vresult_re, vl); + vse32_v_f32m1(&result->im[off_result_U], vresult_im, vl); + + /* Lower triangle */ + if (t1 != t2) { + /* Conjugate (as result_ij = result_ji*) */ + vresult_im = vfneg_v_f32m1(vresult_im, vl); + /* Add R */ + vR = vle32_v_f32m1(&R->re[off_result_L], vl); + vresult_re = vfadd_vv_f32m1(vresult_re, vR, vl); + vR = vle32_v_f32m1(&R->im[off_result_L], vl); + vresult_im = vfadd_vv_f32m1(vresult_im, vR, vl); + /* Store result */ + vse32_v_f32m1(&result->re[off_result_L], vresult_re, vl); + vse32_v_f32m1(&result->im[off_result_L], vresult_im, vl); } - ++off_ijk; + + sz -= vl; + off_sc += vl; } } - } -} -void cmatadd_TxTx( - IN vcomplex *A, - IN vcomplex *B, - OUT vcomplex *result) -{ - vfloat32m1_t vA, vB, vresult; - size_t vl, sz = NUM_TX_ANT * NUM_TX_ANT * NUM_SC; - size_t off = 0; - while (sz > 0) { - vl = vsetvl_e32m1(sz); - - /* real part */ - vA = vle32_v_f32m1(&A->re[off], vl); - vB = vle32_v_f32m1(&B->re[off], vl); - vresult = vfadd_vv_f32m1(vA, vB, vl); - vse32_v_f32m1(&result->re[off], vresult, vl); - - /* imaginary part */ - vA = vle32_v_f32m1(&A->im[off], vl); - vB = vle32_v_f32m1(&B->im[off], vl); - vresult = vfadd_vv_f32m1(vA, vB, vl); - vse32_v_f32m1(&result->im[off], vresult, vl); - - sz -= vl; - off += vl; + printf("result:\n["); + for (size_t i = 0; i < NUM_TX_ANT; ++i) { + printf("["); + for (size_t j = 0; j < NUM_TX_ANT; ++j) { + for (size_t k = 0; k < NUM_SC; ++k) { + printf("\t%f %fj", result->re[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k], result->im[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k]); + } + } + printf("]\n"); } + printf("]\n"); } void ccholesky_TxTx( @@ -490,21 +506,11 @@ void mmse( LH.re = (data_t *)LH_re; LH.im = (data_t *)LH_im; - /* H^H */ - TIME( - "Hermitian transpose (RxTx): %ld\n", - cmat_hermitian_transpose_RxTx, - H, &HH); - /* H^H*H */ - TIME( - "Matmul (TxRx x 
RxTx): %ld\n", - cmatmul_TxRx_RxTx, - &HH, H, &HH_H); /* H^H*H + R */ TIME( - "Matadd (TxTx + TxTx): %ld\n", - cmatadd_TxTx, - &HH_H, R, &HH_H); + "Gram and add (RxTx x TxRx + TxTx): %ld\n", + cmatgram_TxRx_cadd, + H, R, &HH_H); /* L: (H^H*H + R) = L*L^H */ TIME( "Cholesky (TxTx): %ld\n", @@ -569,6 +575,7 @@ void mmse_nosqrt( LH.im = (data_t *)LH_im; /* H^H */ + #if 0 TIME( "Hermitian transpose (RxTx): %ld\n", cmat_hermitian_transpose_RxTx, @@ -618,6 +625,7 @@ void mmse_nosqrt( "Backward substitution (TxTx): %ld\n", cbackwardsub_TxTx, &LH, &z, x_MMSE); + #endif } acc_t mse(