Skip to content

Commit

Permalink
scatter plots over long sequence look usable
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Jan 4, 2025
1 parent 53dedf0 commit 9481435
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions est_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
os.environ['CUDA_VISIBLE_DEVICES'] = ""
device = torch.device("cpu")

def snr_est_test(model, snr_target, h, Nw, test_S1=False, genie_phase=True):
def snr_est_test(model, snr_target, h, Nw, test_S1=False):

Nc = model.Nc
Pc = np.array(model.pilot_gain*model.P)
Expand All @@ -63,6 +63,7 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False, genie_phase=True):
Pcn_hat = h*Pcn + n

# phase corrected received pilots
genie_phase = not args.eq_ls
if genie_phase:
Rcn_hat = np.abs(h)*Pcn + n
else:
Expand All @@ -74,7 +75,8 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False, genie_phase=True):
rx_phase = np.angle(rx_pilots)
#print(rx_phase.shape)
#print(rx_phase)
Rcn_hat = Pcn_hat*np.exp(-1j*rx_phase)
Rcn_hat = Pcn_hat *np.exp(-1j*rx_phase)

if args.plots:
plt.figure(1)
plt.plot(Rcn_hat.real, Rcn_hat.imag,'b+')
Expand Down Expand Up @@ -119,31 +121,31 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False, genie_phase=True):
print("")

# single timestep test
def single(snrdB, h, Nw, test_S1, genie_phase):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h, Nw, test_S1, genie_phase)
def single(snrdB, h, Nw, test_S1):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h, Nw, test_S1)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {10*np.log10(snr_check):5.2f} snrdB_est: {10*np.log10(snr_est):5.2f}")

# run over a sequence of timesteps, and return mean
# run over a sequence of timesteps, and return lists of each each est
def sequence(Ntimesteps, snrdB, h, Nw):
sum_snrdB_est = 0
sum_snrdB_check = 0
snrdB_est_list = []
snrdB_check_list = []

for i in range(Ntimesteps):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h[i*Nw:(i+1)*Nw,:], Nw)
snrdB_check = 10*np.log10(snr_check)
snrdB_est = 10*np.log10(snr_est)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {snrdB_check:5.2f} snrdB_est: {snrdB_est:5.2f}")
sum_snrdB_est += snrdB_est
sum_snrdB_check += snrdB_check
snrdB_est_list = np.append(snrdB_est_list, snrdB_est)
snrdB_check_list = np.append(snrdB_check_list, snrdB_check)

return sum_snrdB_check/Ntimesteps, sum_snrdB_est/Ntimesteps
return snrdB_est_list, snrdB_check_list

# sweep across SNRs
def sweep(Ntimesteps, h, Nw):

EsNodB_check = []
EsNodB_est = []
r = range(-5,15)
r = range(-5,20)
for aEsNodB in r:
aEsNodB_check, aEsNodB_est = sequence(Ntimesteps, aEsNodB, h, Nw)
EsNodB_check = np.append(EsNodB_check, aEsNodB_check)
Expand All @@ -152,7 +154,10 @@ def sweep(Ntimesteps, h, Nw):
plt.figure(1)
plt.plot(EsNodB_check, EsNodB_est,'b+')
plt.plot(r,r)
plt.axis([-5, 20, -5, 20])
plt.grid()
plt.xlabel('SNR (dB)')
plt.ylabel('SNR est (dB)')
plt.show()

# save test file of test points for Latex plotting in Octave radae_plots.m:est_snr_plot()
Expand All @@ -165,7 +170,7 @@ def sweep(Ntimesteps, h, Nw):
parser.add_argument('--sequence', action='store_true', help='run over a sequence of timesteps')
parser.add_argument('--h_file', type=str, default="", help='path to rate Rs multipath samples, rate Rs time steps by Nc carriers .f32 format')
parser.add_argument('-T', type=float, default=1.0, help='length of time window for estimate (default 1.0 sec)')
parser.add_argument('--Nt', type=int, default=1, help='number of analysis time windows to average (default 1)')
parser.add_argument('--Nt', type=int, default=1, help='number of analysis time windows to test across (default 1)')
parser.add_argument('--test_S1', action='store_true', help='calculate S1 two ways to check S1 expression')
parser.add_argument('--eq_ls', action='store_true', help='est phase from received pilots usin least square (default genie phase)')
parser.add_argument('--plots', action='store_true', help='debug plots (default off)')
Expand All @@ -183,7 +188,7 @@ def sweep(Ntimesteps, h, Nw):
h = np.ones((Nw*args.Nt,model.Nc))

if args.single:
single(args.snrdB, h, Nw, args.test_S1, not args.eq_ls)
single(args.snrdB, h, Nw, args.test_S1)
elif args.sequence:
sequence(args.Nt, args.snrdB, h, Nw)
else:
Expand Down

0 comments on commit 9481435

Please sign in to comment.