mmserv

Minimum Mean Square Error (MMSE) detection on the RISC-V Vector Extension
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

mmse_nosqrt.py (1613B)


      1 #!/usr/bin/python
      2 
      3 from util import read_defines, load_xHRy
      4 
      5 import numpy as np
      6 #from scipy.linalg import ldl
      7 from os import path, makedirs
      8 
      9 # Read the data from the files
     10 NUM_RX_ANT, NUM_TX_ANT, NUM_SC = read_defines()
     11 x, H, R, y = load_xHRy(NUM_RX_ANT, NUM_TX_ANT, NUM_SC)
     12 x_mmse = np.empty((NUM_TX_ANT, NUM_SC), np.complex64)
     13 
     14 
     15 def ldl(A):
     16     n = A.shape[0]
     17     D = np.zeros_like(A)
     18     L = np.eye(n, dtype=A.dtype)
     19     for i in range(n):
     20         for j in range(i):
     21             s = 0.+0j
     22             for l in range(j):
     23                 s += L[i][l] * L[j][l].conj() * D[l][l]
     24             L[i][j] = (A[i][j] - s) / D[j][j]
     25         s = 0.+0.j
     26         for j in range(i):
     27             s += L[i][j] * L[i][j].conj() * D[j][j]
     28         D[i][i] = A[i][i] - s
     29     return L, D
     30 
     31 
     32 for sc in range(NUM_SC):
     33     HH = H[..., sc].conj().T
     34     L, D = ldl(HH @ H[..., sc] + R[..., sc])
     35     HHy = np.einsum("ij,j->i", HH, y[:, sc])
     36 
     37     z = np.empty((NUM_TX_ANT,), dtype=np.complex64)
     38     for i in range(NUM_TX_ANT):
     39         z[i] = (HHy[i] - np.sum([L[i, j]*z[j] for j in range(i)]))
     40 
     41     LH = L.conj().T
     42     for i in range(NUM_TX_ANT-1, -1, -1):
     43         x_mmse[i, sc] = z[i] / D[i, i] - np.sum([L[i, j] * x[j, sc] for j in range(i+1, NUM_TX_ANT)])
     44 
     45 # Output the data
     46 out_dir = path.join(path.dirname(__file__), "..", "out")
     47 if not path.exists(out_dir):
     48     makedirs(out_dir)
     49 with open(path.join(out_dir, "x_mmse_nosqrt_python_re.bin"), "wb") as f:
     50     f.write(x_mmse.real.astype(np.float32).tobytes())
     51 with open(path.join(out_dir, "x_mmse_nosqrt_python_im.bin"), "wb") as f:
     52     f.write(x_mmse.imag.astype(np.float32).tobytes())