mmserv

Minimum Mean Square Error detection on RISC-V Vector Extention
git clone https://git.ea.contact/mmserv
Log | Files | Refs | README

commit 9362e7014154afd7772f908b36aee6cada2b8b7b
parent 925cbd30525e39af40697446154ee081c8f74f19
Author: Egor Achkasov <eaachkasov@edu.hse.ru>
Date:   Thu, 31 Oct 2024 15:28:07 +0100

Improve path handling in scripts

Diffstat:
Mscripts/gen_data.py | 7+++++++
Mscripts/mmse.py | 6+++++-
Mscripts/util.py | 9+++++----
3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/scripts/gen_data.py b/scripts/gen_data.py @@ -5,6 +5,7 @@ from util import read_defines, interleave import numpy as np from numpy.random import random, normal from sys import argv +from os import path, makedirs if "--help" in argv: print("Usage: python scripts/gen_data.py [--txt] [--bin] [--s]") @@ -63,6 +64,12 @@ sections = [ Section("y", "data/y.txt", "3", 16, y.size * 2), ] +# Create "data" directory if it does not exist +if WRITE_DATA_BIN or WRITE_DATA_TXT: + data_dir = path.join(path.dirname(__file__), "..", "data") + if not path.exists(data_dir): + makedirs(data_dir) + if WRITE_DATA_TXT: for data, sec in zip(data_tuple, sections): with open(sec.source, "w") as f: diff --git a/scripts/mmse.py b/scripts/mmse.py @@ -3,6 +3,7 @@ from util import read_defines, deinterleave, interleave import numpy as np +from os import path, makedirs NUM_RX_ANT, NUM_TX_ANT, NUM_SC = read_defines() @@ -40,5 +41,8 @@ y = np.moveaxis(y, 0, -1) # Cast the complex data to int16 x_mmse_data = interleave(x_mmse) -with open("out/x_mmse_python.bin", "wb") as f: +out_dir = path.join(path.dirname(__file__), "..", "out") +if not path.exists(out_dir): + makedirs(out_dir) +with open(path.join(out_dir, "x_mmse_python.bin"), "wb") as f: f.write(x_mmse_data.tobytes()) diff --git a/scripts/util.py b/scripts/util.py @@ -1,3 +1,4 @@ +from os import path import numpy as np @@ -9,14 +10,14 @@ def read_defines(): int: Number of transmit antennas int: Number of subcarriers """ - with open(__file__[:__file__.rfind("/")] + "/../inc/define.h", "r") as f: + with open(path.join(path.dirname(__file__), "..", "inc", "define.h"), "r") as f: lines = f.read().split("\n") for line in lines: - if line[:19] == "#define NUM_RX_ANT ": + if line.startswith("#define NUM_RX_ANT "): NUM_RX_ANT = int(line[19:]) - if line[:19] == "#define NUM_TX_ANT ": + if line.startswith("#define NUM_TX_ANT "): NUM_TX_ANT = int(line[19:]) - if line[:15] == "#define NUM_SC ": + if line.startswith("#define NUM_SC "): NUM_SC = int(line[15:]) # Assert that all the defines are read