tune v5 and plot v1

JunangWang committed Mar 21, 2024
1 parent 809df63 commit 2cc9178
Showing 6 changed files with 599 additions and 211 deletions.
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"cSpell.words": [
"denorm"
]
}
406 changes: 360 additions & 46 deletions Modeling eMNS/Generative_model_v2.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Modeling eMNS/Neural_network.py
@@ -5,10 +5,10 @@
import numpy as np
# set up dataset class
class eMNS_Dataset(torch.utils.data.Dataset):
def __init__(self,train_x,train_y):
def __init__(self,x,y):
#data loading
self.x = train_x
self.y = train_y
self.x = x
self.y = y
self.n_samples = self.x.shape[0]


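
With train_x/train_y renamed to the generic x/y, the same dataset class can wrap either the training or the validation split. A minimal usage sketch, assuming __getitem__ and __len__ are defined further down in the file (not shown in this hunk) and with illustrative tensor sizes:

    import torch
    from torch.utils.data import DataLoader
    from Neural_network import eMNS_Dataset

    # illustrative data: 12 input values per sample, 3-component field on a 16^3 grid
    x = torch.randn(100, 12)
    y = torch.randn(100, 3, 16, 16, 16)

    valid_set = eMNS_Dataset(x, y)
    valid_loader = DataLoader(valid_set, batch_size=16, shuffle=False)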
182 changes: 33 additions & 149 deletions Modeling eMNS/Training_loop_v2.py
@@ -5,7 +5,7 @@
from torch.nn.parallel import DistributedDataParallel
import torch.nn.functional as F
from early_stopping import EarlyStopping, EarlyDecay
from utils import compute_discrete_curl, denorm, max_min_norm, denorm_ray
from utils import Jacobian3, grad_loss_Jacobain, check_rmse_CNN, compute_discrete_curl, compute_discrete_divergence, grad_loss, get_mean_of_dataloader, gridData_reshape
from Neural_network import ResidualEMNSBlock_3d, BigBlock, Generative_net
import numpy as np
import ray
@@ -84,7 +84,6 @@ def train_GM(config):
"""
#---------------unpack config---------------------
# print(config)
batch_size = config['batch_size']
epochs = config["epochs"]
verbose = config['verbose']
lr_max = config['lr_max']
@@ -95,25 +94,14 @@
learning_rate_decay = config['learning_rate_decay']
maxB = config['maxB']
minB = config['minB']
skip_spacing = config['skip_spacing']
num_repeat = config['num_repeat']
num_block = config['num_block']
device = config['device']
train_set = config['train_set']
valid_set = config['valid_set']

####################################################
#--------------model construction------------------
####################################################
num_input = 12
output_shape = (3,16,16,16)
SB_args = (64,64,skip_spacing,num_repeat) # (Cin, Cout, skip_spacing, num_repeat)
BB_args = (2,num_block) # (scale_factor, num_block)
SB_block = ResidualEMNSBlock_3d
BB_block = BigBlock


model = Generative_net(SB_args, BB_args, SB_block, BB_block, num_input=num_input, output_shape= output_shape)
model = construct_model_GM(config)



@@ -226,20 +214,24 @@ def train_GM(config):
print()
adjust_epoch_count += 1

# # create checkpoint
# base_model = (model.module
# if isinstance(model, DistributedDataParallel) else model)
# checkpoint_dir = tempfile.mkdtemp()
# # load back training state
# checkpoint_data = {
# "epoch": epoch,
# "net_state_dict": base_model.state_dict(),
# "optimizer_state_dict": optimizer.state_dict(),
# }
# torch.save(checkpoint_data, os.path.join(checkpoint_dir, "model.pt"))
# checkpoint = Checkpoint.from_directory(checkpoint_dir)
#Send the current training result back to Tune
train.report({'rmse_val':rmse_val.item(), 'rmse_train': rmse.item(), 'loss':loss.item()})
if epoch % (epochs-1) == 0:
# create a checkpoint only at the first and last epoch
base_model = (model.module
if isinstance(model, DistributedDataParallel) else model)
checkpoint_dir = tempfile.mkdtemp()
# collect the training state to store in the checkpoint
checkpoint_data = {
"epoch": epoch,
"net_state_dict": base_model.state_dict(),
'model': base_model
# "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint_data, os.path.join(checkpoint_dir, "model.pt"))
checkpoint = Checkpoint.from_directory(checkpoint_dir)
# Send the current training result back to Tune
train.report({'rmse_val':rmse_val.item(), 'rmse_train': rmse.item(), 'loss':loss.item()},checkpoint=checkpoint)
else:
train.report({'rmse_val':rmse_val.item(), 'rmse_train': rmse.item(), 'loss':loss.item()})
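
Since the checkpoint now travels with the Tune report, the saved model can be recovered from the results once tuning finishes. A minimal sketch, assuming results is the ResultGrid returned by tuner.fit() (the variable names are illustrative):

    import os
    import torch

    best_result = results.get_best_result(metric="rmse_val", mode="min")
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        # "model.pt" matches the filename written inside train_GM above
        checkpoint_data = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = checkpoint_data["model"]
    print(checkpoint_data["epoch"])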



@@ -250,131 +242,23 @@ def train_GM(config):

return rmse_history, rmse_val_history,loss_history, iter_history,mse_history, mse_val_history,epoch_stop,Rsquare
#-------------------------------------------------------------------------------------------------------
def construct_model_GM(config):
num_input = config['num_input']
skip_spacing = config['skip_spacing']
num_repeat = config['num_repeat']
num_block = config['num_block']
output_shape = (3,16,16,16)
SB_args = (64,64,skip_spacing,num_repeat) # (Cin, Cout, skip_spacing, num_repeat)
BB_args = (2,num_block) # (scale_factor, num_block)
SB_block = ResidualEMNSBlock_3d
BB_block = BigBlock


model = Generative_net(SB_args, BB_args, SB_block, BB_block, num_input=num_input, output_shape= output_shape)
return model
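
construct_model_GM only reads the architecture-related keys of the Tune config, so it can also rebuild the network outside a trial. A sketch with illustrative values (num_input = 12 comes from the inline construction removed above; the other numbers are made up):

    config = {
        'num_input': 12,     # as in the removed inline construction
        'skip_spacing': 1,   # illustrative
        'num_repeat': 2,     # illustrative
        'num_block': 3,      # illustrative
    }
    model = construct_model_GM(config)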

def get_mean_of_dataloader(dataloader,model,device):
num_samples = 0
b = torch.zeros(1,device=device)
model.eval()
for x,y in dataloader:
y = y.to(device=device,dtype=torch.float)
# sum over the batch here; divide by the total sample count at the end to get the mean
y_sum = y.sum(dim=0,keepdim=True)
num_samples += y.shape[0]
# print(y.shape[0])
b =b+y_sum
return b/num_samples


def check_rmse_CNN(dataloader,model, grid_space, device, DF, maxB=[],minB=[]):
'''
Check RMSE of CNN
'''
mse_temp = 0
R_temp=0
Rsquare=0
num_samples = 0
# print(Bfield_mean)

data = next(iter(dataloader))
mean = data[0].mean()

Bfield_mean=get_mean_of_dataloader(dataloader,model,device)

model.eval() # set model to evaluation mode

with torch.no_grad():
for x,y in dataloader:
x = x.to(device=device,dtype=torch.float)
y = y.to(device=device,dtype=torch.float)
num_samples += x.shape[0]
if DF:
_, scores = Jacobian3(model(x))
else:
scores = model(x)

# compute MSE and R2 on de-normalized data
mse_temp += F.mse_loss(1e3*denorm(scores,maxB,minB,device), 1e3*denorm(y,maxB,minB, device) ,reduction='sum')
R_temp += F.mse_loss(1e3*denorm(Bfield_mean.expand_as(y),maxB,minB,device), 1e3*denorm(y,maxB,minB,device), reduction='sum')


rmse = torch.sqrt(mse_temp/num_samples/grid_space/3)

Rsquare=1-mse_temp/R_temp/num_samples
print(f'Got rmse {rmse}')




return rmse, mse_temp/num_samples/grid_space/3, Rsquare
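
Written out, with SS_res = sum of ||B_pred - B_true||^2 and SS_tot = sum of ||B_mean - B_true||^2 accumulated over all samples and grid points (both on de-normalized fields scaled by 1e3), the returned quantities are:

    rmse    = sqrt( SS_res / (N * grid_space * 3) )
    mse     = SS_res / (N * grid_space * 3)
    Rsquare = 1 - SS_res / (SS_tot * N)

where N is the number of samples and grid_space the number of grid points per field component.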

#-----------------------------------------------------------------


#----------------------------------------------------------------
def grad_loss(preds, y):
'''
preds, y shape: (batch, dimension, grid_x, grid_y, grid_z)
This function computes lambda_g * |nabla(y) - nabla(preds)|
'''
grad_preds = torch.gradient(preds,spacing=1.0)
grad_y = torch.gradient(y, spacing=1)
grad_loss = 0
for i in range(2,5):
# accumulate gradient loss over the x, y, z spatial axes
grad_loss += torch.mean(torch.abs(grad_y[i]-grad_preds[i]))/3
return grad_loss

def grad_loss_Jacobain(preds,y):
'''
preds, y shape: (batch, dimension, grid_x, grid_y, grid_z)
This function computes lambda_g * |nabla(y) - nabla(preds)| via the Jacobian
'''
Jaco_preds,_ = Jacobian3(preds)
Jaco_y ,_ = Jacobian3(y)

grad_loss = torch.mean(torch.abs(Jaco_preds - Jaco_y))

return grad_loss
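
Both gradient-loss variants take full field tensors of shape (batch, 3, grid_x, grid_y, grid_z) and return a scalar. A quick sanity-check sketch with illustrative shapes:

    preds  = torch.randn(4, 3, 16, 16, 16)
    target = torch.randn(4, 3, 16, 16, 16)

    l1 = grad_loss(preds, target)           # central differences via torch.gradient
    l2 = grad_loss_Jacobain(preds, target)  # one-sided differences via Jacobian3
    print(l1.item(), l2.item())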


def Jacobian3(x):
'''
Jacobian for 3D vector field
-------input----------
x shape: (batch, dimension,grid_x, grid_y, grid_z)
'''

dudx = x[:, 0, 1:, :, :] - x[:, 0, :-1, :, :]
dvdx = x[:, 1, 1:, :, :] - x[:, 1, :-1, :, :]
dwdx = x[:, 2, 1:, :, :] - x[:, 2, :-1, :, :]

dudy = x[:, 0, :, 1:, :] - x[:, 0, :, :-1, :]
dvdy = x[:, 1, :, 1:, :] - x[:, 1, :, :-1, :]
dwdy = x[:, 2, :, 1:, :] - x[:, 2, :, :-1, :]

dudz = x[:, 0, :, :, 1:] - x[:, 0, :, :, :-1]
dvdz = x[:, 1, :, :, 1:] - x[:, 1, :, :, :-1]
dwdz = x[:, 2, :, :, 1:] - x[:, 2, :, :, :-1]

dudx = torch.cat((dudx, torch.unsqueeze(dudx[:,-1],dim=1)), dim=1)
dvdx = torch.cat((dvdx, torch.unsqueeze(dvdx[:,-1],dim=1)), dim=1)
dwdx = torch.cat((dwdx, torch.unsqueeze(dwdx[:,-1],dim=1)), dim=1)

dudy = torch.cat((dudy, torch.unsqueeze(dudy[:,:,-1],dim=2)), dim=2)
dvdy = torch.cat((dvdy, torch.unsqueeze(dvdy[:,:,-1],dim=2)), dim=2)
dwdy = torch.cat((dwdy, torch.unsqueeze(dwdy[:,:,-1],dim=2)), dim=2)

dudz = torch.cat((dudz, torch.unsqueeze(dudz[:,:,:,-1],dim=3)), dim=3)
dvdz = torch.cat((dvdz, torch.unsqueeze(dvdz[:,:,:,-1],dim=3)), dim=3)
dwdz = torch.cat((dwdz, torch.unsqueeze(dwdz[:,:,:,-1],dim=3)), dim=3)

u = dwdy - dvdz
v = dudz - dwdx
w = dvdx - dudy

j = torch.stack([dudx,dudy,dudz,dvdx,dvdy,dvdz,dwdx,dwdy,dwdz],axis=1)
c = torch.stack([u,v,w],axis=1) #vorticity

return j,c
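
Jacobian3 pads each one-sided difference by repeating its last slice, so the outputs keep the input grid size. A shape sketch with illustrative sizes:

    B = torch.randn(2, 3, 16, 16, 16)
    j, c = Jacobian3(B)
    print(j.shape)  # torch.Size([2, 9, 16, 16, 16]) -- the nine partial derivatives
    print(c.shape)  # torch.Size([2, 3, 16, 16, 16]) -- curl components (dwdy-dvdz, dudz-dwdx, dvdx-dudy)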

