commit 43fb5cc3e9d9bcb9f778aafa8faa9b86b360b735
parent 8b6b928b2efe23947a62358bf75d34f31de876ab
Author: Egor Achkasov <eaachkasov@edu.hse.ru>
Date: Fri, 24 Jan 2025 13:01:31 +0100
Replace matmul with gram matrix
Diffstat:
| M | src/mmserv.c | | | 228 | +++++++++++++++++++++++++++++++++++++++++-------------------------------------- |
1 file changed, 118 insertions(+), 110 deletions(-)
diff --git a/src/mmserv.c b/src/mmserv.c
@@ -5,6 +5,10 @@
#include <stddef.h> /* for size_t */
+/*
+ * Debug
+ */
+
#ifdef DEBUG
#include "../../common/runtime.h"
#include "../../common/util.h"
@@ -84,23 +88,6 @@ void csqrt(
* Complex matrix operations
*/
-void cmat_hermitian_transpose_RxTx(
- IN vcomplex *A,
- OUT vcomplex *AH)
-{
- size_t i, j, k, off_ijk = 0, off_jik;
- for (i = 0; i < NUM_RX_ANT; ++i)
- for (j = 0; j < NUM_TX_ANT; ++j) {
- off_jik = j * NUM_RX_ANT * NUM_SC + i * NUM_SC;
- for (k = 0; k < NUM_SC; ++k) {
- AH->re[off_jik] = A->re[off_ijk];
- AH->im[off_jik] = -A->im[off_ijk];
- ++off_ijk;
- ++off_jik;
- }
- }
-}
-
void cmat_hermitian_transpose_TxTx(
IN vcomplex *A,
OUT vcomplex *AH)
@@ -118,98 +105,127 @@ void cmat_hermitian_transpose_TxTx(
}
}
-void cmatmul_TxRx_RxTx(
+/** Complex Gram matrix A^H*A and add complex matrix R (A^H*A + R)
+ * \param A matrix of channel coefficients. Shape [NUM_RX_ANT][NUM_TX_ANT][NUM_SC]
+ * \param R noise covariance matrix. Shape [NUM_TX_ANT][NUM_TX_ANT][NUM_SC]
+ * \param result output Gram matrix. Shape [NUM_TX_ANT][NUM_TX_ANT][NUM_SC]
+ */
+void cmatgram_TxRx_cadd(
IN vcomplex *A,
- IN vcomplex *B,
+ IN vcomplex *R,
OUT vcomplex *result)
{
- size_t i, j, k, l;
- size_t off_ijk = 0, off_ilk, off_ljk;
- data_t A_re, A_im, B_re, B_im;
-
- for (i = 0; i < NUM_TX_ANT * NUM_TX_ANT * NUM_SC; ++i)
- result->re[i] = result->im[i] = 0.f;
-
- for (i = 0; i != NUM_TX_ANT; ++i) {
- for (j = 0; j != NUM_RX_ANT; ++j) {
- for (k = 0; k != NUM_SC; ++k) {
- off_ilk = i * NUM_RX_ANT * NUM_SC + k;
- off_ljk = j * NUM_SC + k;
- for (l = 0; l != NUM_RX_ANT; ++l) {
- A_re = A->re[off_ilk];
- A_im = A->im[off_ilk];
- B_re = B->re[off_ljk];
- B_im = B->im[off_ljk];
- result->re[off_ijk] += A_re * B_re - A_im * B_im;
- result->im[off_ijk] += A_re * B_im + A_im * B_re;
- off_ilk += NUM_SC;
- off_ljk += NUM_TX_ANT * NUM_SC;
- }
- ++off_ijk;
+ size_t t1, t2, r;
+ size_t sz, vl;
+ size_t off_sc, off_A, off_AH;
+ size_t off_result_L, off_result_U;
+ vfloat32m1_t vA_re, vA_im, vAH_re, vAH_im;
+ vfloat32m1_t vR;
+ vfloat32m1_t vresult_re, vresult_im;
+ vfloat32m1_t vt;
+
+ printf("A:\n[");
+ for (size_t i = 0; i < NUM_TX_ANT; ++i) {
+ printf("[");
+ for (size_t j = 0; j < NUM_RX_ANT; ++j) {
+ for (size_t k = 0; k < NUM_SC; ++k) {
+ printf("\t%f %+fj,", A->re[i * NUM_RX_ANT * NUM_SC + j * NUM_SC + k], A->im[i * NUM_RX_ANT * NUM_SC + j * NUM_SC + k]);
}
}
+ printf("],\n");
}
-}
-
-void cmatmul_TxTx_TxTx(
- IN vcomplex *A,
- IN vcomplex *B,
- OUT vcomplex *result)
-{
- size_t i, j, k, l;
- size_t off_ijk = 0, off_ilk, off_ljk;
- data_t A_re, A_im, B_re, B_im;
-
- for (i = 0; i < NUM_TX_ANT * NUM_TX_ANT * NUM_SC; ++i)
- result->re[i] = result->im[i] = 0.f;
+ printf("]\n");
+
+ printf("R:\n[");
+ for (size_t i = 0; i < NUM_TX_ANT; ++i) {
+ printf("[");
+ for (size_t j = 0; j < NUM_TX_ANT; ++j) {
+ for (size_t k = 0; k < NUM_SC; ++k) {
+ printf("\t%f %fj,", R->re[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k], R->im[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k]);
+ }
+ }
+ printf("],\n");
+ }
+ printf("]\n");
+
+ for (t1 = 0; t1 != NUM_TX_ANT; ++t1)
+ for (t2 = 0; t2 != NUM_TX_ANT; ++t2)
+ for (r = 0; r != NUM_SC; ++r)
+ result->re[t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + r]
+ = result->im[t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + r] = 0.f;
+
+ for (t1 = 0; t1 != NUM_TX_ANT; ++t1)
+ for (t2 = t1; t2 != NUM_TX_ANT; ++t2) {
+ off_sc = 0;
+ off_result_L = t1 * NUM_TX_ANT * NUM_SC + t2 * NUM_SC;
+ off_result_U = t2 * NUM_TX_ANT * NUM_SC + t1 * NUM_SC;
+ sz = NUM_SC;
+ while (sz > 0) {
+ vl = vsetvl_e32m1(sz);
+ vresult_re = vfmv_v_f_f32m1(0.f, vl);
+ vresult_im = vfmv_v_f_f32m1(0.f, vl);
+
+ for (r = 0; r != NUM_RX_ANT; ++r) {
+ off_A = r * NUM_TX_ANT * NUM_SC + t1 * NUM_SC + off_sc;
+ off_AH = r * NUM_TX_ANT * NUM_SC + t2 * NUM_SC + off_sc;
+ vA_re = vle32_v_f32m1(&A->re[off_A], vl);
+ vA_im = vle32_v_f32m1(&A->im[off_A], vl);
+ vAH_re = vle32_v_f32m1(&A->re[off_AH], vl);
+ vAH_im = vle32_v_f32m1(&A->im[off_AH], vl);
+
+ /* real part */
+ vt = vfmul_vv_f32m1(vA_re, vAH_re, vl);
+ vresult_re = vfadd_vv_f32m1(vresult_re, vt, vl);
+ vt = vfmul_vv_f32m1(vA_im, vAH_im, vl);
+ vresult_re = vfadd_vv_f32m1(vresult_re, vt, vl);
+
+ /* imaginary part */
+ vt = vfmul_vv_f32m1(vA_im, vAH_re, vl);
+ vresult_im = vfadd_vv_f32m1(vresult_im, vt, vl);
+ vt = vfmul_vv_f32m1(vA_re, vAH_im, vl);
+ vresult_im = vfsub_vv_f32m1(vresult_im, vt, vl);
+ }
- for (i = 0; i != NUM_TX_ANT; ++i) {
- for (j = 0; j != NUM_TX_ANT; ++j) {
- for (k = 0; k != NUM_SC; ++k) {
- off_ilk = i * NUM_TX_ANT * NUM_SC + k;
- off_ljk = j * NUM_SC + k;
- for (l = 0; l != NUM_TX_ANT; ++l) {
- A_re = A->re[off_ilk];
- A_im = A->im[off_ilk];
- B_re = B->re[off_ljk];
- B_im = B->im[off_ljk];
- result->re[off_ijk] += A_re * B_re - A_im * B_im;
- result->im[off_ijk] += A_re * B_im + A_im * B_re;
- off_ilk += NUM_SC;
- off_ljk += NUM_TX_ANT * NUM_SC;
+ /* Upper triangle */
+ /* Add R */
+ vR = vle32_v_f32m1(&R->re[off_result_U], vl);
+ vresult_re = vfadd_vv_f32m1(vresult_re, vR, vl);
+ vR = vle32_v_f32m1(&R->im[off_result_U], vl);
+ vresult_im = vfadd_vv_f32m1(vresult_im, vR, vl);
+ /* Store result */
+ vse32_v_f32m1(&result->re[off_result_U], vresult_re, vl);
+ vse32_v_f32m1(&result->im[off_result_U], vresult_im, vl);
+
+ /* Lower triangle */
+ if (t1 != t2) {
+ /* Conjugate (as result_ij = result_ji*) */
+ vresult_im = vfneg_v_f32m1(vresult_im, vl);
+ /* Add R */
+ vR = vle32_v_f32m1(&R->re[off_result_L], vl);
+ vresult_re = vfadd_vv_f32m1(vresult_re, vR, vl);
+ vR = vle32_v_f32m1(&R->im[off_result_L], vl);
+ vresult_im = vfadd_vv_f32m1(vresult_im, vR, vl);
+ /* Store result */
+ vse32_v_f32m1(&result->re[off_result_L], vresult_re, vl);
+ vse32_v_f32m1(&result->im[off_result_L], vresult_im, vl);
}
- ++off_ijk;
+
+ sz -= vl;
+ off_sc += vl;
}
}
- }
-}
-void cmatadd_TxTx(
- IN vcomplex *A,
- IN vcomplex *B,
- OUT vcomplex *result)
-{
- vfloat32m1_t vA, vB, vresult;
- size_t vl, sz = NUM_TX_ANT * NUM_TX_ANT * NUM_SC;
- size_t off = 0;
- while (sz > 0) {
- vl = vsetvl_e32m1(sz);
-
- /* real part */
- vA = vle32_v_f32m1(&A->re[off], vl);
- vB = vle32_v_f32m1(&B->re[off], vl);
- vresult = vfadd_vv_f32m1(vA, vB, vl);
- vse32_v_f32m1(&result->re[off], vresult, vl);
-
- /* imaginary part */
- vA = vle32_v_f32m1(&A->im[off], vl);
- vB = vle32_v_f32m1(&B->im[off], vl);
- vresult = vfadd_vv_f32m1(vA, vB, vl);
- vse32_v_f32m1(&result->im[off], vresult, vl);
-
- sz -= vl;
- off += vl;
+ printf("result:\n[");
+ for (size_t i = 0; i < NUM_TX_ANT; ++i) {
+ printf("[");
+ for (size_t j = 0; j < NUM_TX_ANT; ++j) {
+ for (size_t k = 0; k < NUM_SC; ++k) {
+ printf("\t%f %fj", result->re[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k], result->im[i * NUM_TX_ANT * NUM_SC + j * NUM_SC + k]);
+ }
+ }
+ printf("]\n");
}
+ printf("]\n");
}
void ccholesky_TxTx(
@@ -490,21 +506,11 @@ void mmse(
LH.re = (data_t *)LH_re;
LH.im = (data_t *)LH_im;
- /* H^H */
- TIME(
- "Hermitian transpose (RxTx): %ld\n",
- cmat_hermitian_transpose_RxTx,
- H, &HH);
- /* H^H*H */
- TIME(
- "Matmul (TxRx x RxTx): %ld\n",
- cmatmul_TxRx_RxTx,
- &HH, H, &HH_H);
/* H^H*H + R */
TIME(
- "Matadd (TxTx + TxTx): %ld\n",
- cmatadd_TxTx,
- &HH_H, R, &HH_H);
+ "Gram and add (RxTx x TxRx + TxTx): %ld\n",
+ cmatgram_TxRx_cadd,
+ H, R, &HH_H);
/* L: (H^H*H + R) = L*L^H */
TIME(
"Cholesky (TxTx): %ld\n",
@@ -569,6 +575,7 @@ void mmse_nosqrt(
LH.im = (data_t *)LH_im;
/* H^H */
+ #if 0
TIME(
"Hermitian transpose (RxTx): %ld\n",
cmat_hermitian_transpose_RxTx,
@@ -618,6 +625,7 @@ void mmse_nosqrt(
"Backward substitution (TxTx): %ld\n",
cbackwardsub_TxTx,
&LH, &z, x_MMSE);
+ #endif
}
acc_t mse(