v.py (2655B)
"""Plot learned material parameters and the training loss over iterations.

Usage: python <script> path

``path`` is a directory containing ``omegas.npy`` (the read-out vectors and
embeddings saved per iteration) and ``losses.npy`` (the loss history).
"""
import os
from sys import argv
from typing import Tuple

import numpy as np

# Trainable material names. The plot assumes this order matches the material
# axis of ``omegas.npy`` -- NOTE(review): confirm against the training script.
trainable_materials_names = [
    "itu_wood_train",
    "itu_plasterboard_train",
    "itu_metal_train",
    "itu_concrete_train",
]
num_materials = len(trainable_materials_names)


def _sigmoid(x: np.ndarray) -> np.ndarray:
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` without overflow.

    ``log(1 + exp(-x))`` is computed as ``np.logaddexp(0, -x)``, so very
    negative inputs underflow cleanly to 0 instead of overflowing ``exp(-x)``
    and emitting a RuntimeWarning.
    """
    return np.exp(-np.logaddexp(0.0, -x))


def calc_material_parameters(
    w: np.ndarray,
    v: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Calculate the material parameters from the read-out vectors and embeddings.

    ``v`` and ``w`` are multiplied elementwise and summed over their last
    axis; the first axis of that result selects the parameter (0..3).

    Returns:
        ``(cond, perm, scat, xpd)`` with ``d = sum(v * w, axis=-1)``:
        ``cond = exp(d[0])``, ``perm = exp(d[1]) + 1``,
        ``scat = sigmoid(d[2])``, ``xpd = sigmoid(d[3])``.
    """
    dot = np.sum(v * w, axis=-1)
    cond = np.exp(dot[0])
    # exp(...) + 1 keeps the relative permittivity >= 1.
    perm = np.exp(dot[1]) + 1.
    scat = _sigmoid(dot[2])
    xpd = _sigmoid(dot[3])
    return cond, perm, scat, xpd


def main() -> None:
    """Load ``omegas.npy`` and ``losses.npy`` from ``argv[1]`` and plot them.

    Raises:
        RuntimeError: if not exactly one path argument is given.
    """
    # Imported here so the module (e.g. calc_material_parameters) can be
    # imported without pulling in matplotlib or selecting a GUI backend.
    import matplotlib.pyplot as plt

    if len(argv) != 2:
        raise RuntimeError(f"Usage: python {argv[0]} path")
    path = argv[1]

    # os.path.join handles a missing trailing separator (and an empty path)
    # without mutating sys.argv.
    omegas = np.load(os.path.join(path, "omegas.npy"))
    # Layout: (num_iterations, {v, w}, 4 parameters, num_materials, num_embeddings)
    omegas = omegas.reshape(omegas.shape[0], 2, 4, num_materials, -1)
    print("Loaded omegas: ", omegas.shape)

    loss = np.load(os.path.join(path, "losses.npy"))
    print("Loaded loss: ", loss.shape)

    v = omegas[:, 0]  # Shape: (num_iterations, 4, num_materials, num_embeddings)
    w = omegas[:, 1]  # Shape: (num_iterations, 4, num_materials, num_embeddings)

    # Put the parameter axis first so one call evaluates every iteration;
    # each result then has shape (num_iterations, num_materials).
    cond_all, perm_all, scat_all, xpd_all = calc_material_parameters(
        np.moveaxis(w, 1, 0), np.moveaxis(v, 1, 0)
    )

    # Plotting: one column per material plus an extra column whose top row
    # holds the loss curve; the rest of that extra column is removed.
    fig, ax = plt.subplots(4, num_materials + 1, figsize=(12, 8))
    for row in range(1, ax.shape[0]):
        fig.delaxes(ax[row, -1])
    fig.suptitle(f"Material parameters over iterations ({path})")
    ax[0, 0].set_ylabel("Conductivity (S/m)")
    ax[1, 0].set_ylabel("Relative Permittivity")
    ax[2, 0].set_ylabel("Scattering Coefficient")
    ax[3, 0].set_ylabel("XPD Coefficient")
    # fig.axes lists only the axes that survived delaxes.
    for a in fig.axes:
        a.set_xlabel("Iterations")

    # Plot for each material
    for i, mat_name in enumerate(trainable_materials_names):
        ax[0, i].set_title(mat_name)
        ax[0, i].plot(cond_all[:, i], label=mat_name)
        ax[1, i].plot(perm_all[:, i], label=mat_name)
        ax[2, i].plot(scat_all[:, i], label=mat_name)
        ax[3, i].plot(xpd_all[:, i], label=mat_name)

    # Plot loss
    ax[0, -1].plot(loss, label="Loss", color='black', linestyle='--')
    ax[0, -1].set_title("Overall loss")
    ax[0, -1].set_ylabel("Loss")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()