Commit
merge deterministic RL branch
chirath committed Oct 30, 2024
2 parents 22a9620 + d168f98 commit db3b5d6
Showing 26 changed files with 6,046 additions and 214 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,5 +3,6 @@ __pycache__/
*.py[cod]
*.pyc
results/
experiments/nci/
.env
experiments/nci/help.txt
341 changes: 341 additions & 0 deletions agents/ddpg/core.py

Large diffs are not rendered by default.

420 changes: 420 additions & 0 deletions agents/ddpg/ddpg.py

Large diffs are not rendered by default.
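Since this diff is collapsed, here is a minimal, hypothetical sketch of the standard DDPG update for orientation; all names, shapes, and hyperparameters below are assumptions and do not reflect the actual contents of agents/ddpg/ddpg.py:

# Hypothetical sketch of one DDPG update step; the critic is assumed to
# take (state, action) pairs. This is not the code in agents/ddpg/ddpg.py.
import torch
import torch.nn as nn

def ddpg_update(actor, critic, target_actor, target_critic,
                actor_opt, critic_opt, batch, gamma=0.99, tau=0.005):
    s, a, r, s2, done = batch  # tensors sampled from a replay buffer

    # Critic: regress Q(s, a) toward the one-step bootstrapped target,
    # using the target networks for the next-state value.
    with torch.no_grad():
        q_target = r + gamma * (1 - done) * target_critic(s2, target_actor(s2))
    critic_loss = nn.functional.mse_loss(critic(s, a), q_target)
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # Actor: deterministic policy gradient - ascend Q(s, actor(s)).
    actor_loss = -critic(s, actor(s)).mean()
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()

    # Polyak-average the target networks toward the online networks.
    with torch.no_grad():
        for p, tp in zip(actor.parameters(), target_actor.parameters()):
            tp.mul_(1 - tau).add_(tau * p)
        for p, tp in zip(critic.parameters(), target_critic.parameters()):
            tp.mul_(1 - tau).add_(tau * p)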

380 changes: 380 additions & 0 deletions agents/ddpg/models.py

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions agents/ddpg/parameters.py
@@ -0,0 +1,74 @@
def set_args(args):
    args.action_type = 'exponential' # options: 'normal', 'quadratic', 'proportional_quadratic', 'exponential'
args.feature_history = 12 #24
args.calibration = 12 #24
args.action_scale = 5 # 5/factor
args.insulin_max = 5
# these params adjust the pump.
args.expert_bolus = False
args.expert_cf = False
args.n_features = 2
args.t_meal = 20
args.use_meal_announcement = False # adds meal announcement as a timeseries feature.
args.use_carb_announcement = False
args.use_tod_announcement = False
args.use_handcraft = 0
args.n_handcrafted_features = 1
args.n_hidden = 16 # 128
args.n_rnn_layers = 1
args.rnn_directions = 1
args.bidirectional = False
    args.rnn_only = True # RNN + 1 dense layer; deprecated, architecture is fixed
args.max_epi_length = 288 * 10
args.n_step = 256
args.max_test_epi_len = 288
args.gamma = 0.997
    # Parameters above this line are kept fixed for consistency with the other RL algorithms.

    # parameters important to the SAC algorithm
args.entropy_coef = 0.001 # 0.001 seems to work
args.batch_size = 256 if args.debug == 0 else 64 # the mini_batch size
args.replay_buffer_size = 100000 if args.debug == 0 else 1024 # total <s,a,r,s'> pairs 100000
args.sample_size = 4096 if args.debug == 0 else 128 #256

args.sac_v2 = True

    # 200 worked out, 400 running

args.shuffle_rollout = True
args.n_training_workers = 16 if args.debug == 0 else 2
args.n_testing_workers = 20 if args.debug == 0 else 2
args.n_pi_epochs = 5 # can be used to increase number of epochs for all networks updates.
# args.pi_lr = 1e-4 * 3 # 1e-4 * 3
# args.vf_lr = 1e-4 * 3 # 1e-4 * 3
args.grad_clip = 20

    ### todo: refactor - the parameters below are unused
args.eps_clip = 0.1 # 0.05 #0.1 # (Usually small, 0.1 to 0.3.) 0.2
args.target_kl = 0.01 # 0.005 #0.01 # (Usually small, 0.01 or 0.05.)
args.normalize_reward = True
args.reward_lr = 1 * 1e-3
args.aux_lr = 1e-4 * 3
args.n_vf_epochs = 1 # FIXED
args.aux_batch_size = 1024

# (2) => aux model learning
args.n_aux_epochs = 5
args.aux_frequency = 1 # frequency of updates
args.aux_vf_coef = 0.01 #10 #1 #
args.aux_pi_coef = 0.01 #10 #1 #
    # (3) => planning
#args.planning_coef = 1
args.planning_lr = 1e-4 * 3
args.kl = 1
args.use_planning = False #if args.planning_coef == -1 else True
args.planning_n_step = 6
args.plan_type = 4
args.n_planning_simulations = 50
args.plan_batch_size = 1024
args.n_plan_epochs = 1
# clean up below: deprecated
args.bgp_pred_mode = False
    args.n_bgp_steps = 0 # todo: currently fixed; must be changed manually -> fix

return args
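
Note: set_args mutates and returns the parsed argument namespace, so these fixed experiment settings override command-line values. A minimal usage sketch (hypothetical; it assumes only the debug and patient_id fields that the code above and worker.py actually read):

import argparse
from agents.ddpg.parameters import set_args

parser = argparse.ArgumentParser()
parser.add_argument('--debug', type=int, default=0)       # read by set_args for batch/buffer sizing
parser.add_argument('--patient_id', type=int, default=0)  # read by Worker to pick the patient
args = set_args(parser.parse_args())
print(args.batch_size)  # 256 when debug == 0, otherwise 64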
137 changes: 137 additions & 0 deletions agents/ddpg/worker.py
@@ -0,0 +1,137 @@
import csv
import gym
import torch
import itertools
import numpy as np
import pandas as pd
from copy import deepcopy
from collections import deque
from utils.pumpAction import Pump
from utils.core import get_env, time_in_range, custom_reward, combined_shape
from agents.ddpg.core import Memory, StateSpace, composite_reward
from agents.std_bb.BBController import BasalBolusController
from utils.carb_counting import carb_estimate


class Worker:
def __init__(self, args, mode, patients, env_ids, seed, worker_id, device):
self.worker_id = worker_id
self.worker_mode = mode
self.args = args
self.device = device
self.episode = 0
self.update_timestep = args.n_step
self.max_test_epi_len = args.max_test_epi_len
self.max_epi_length = args.max_epi_length
self.calibration = args.calibration
self.simulation_seed = seed + 100
self.patient_name = patients[args.patient_id]
self.env_id = str(worker_id) + '_' + env_ids[args.patient_id]
self.env = get_env(self.args, patient_name=self.patient_name, env_id=self.env_id,
custom_reward=custom_reward, seed=self.simulation_seed)
self.state_space = StateSpace(self.args)
self.pump = Pump(self.args, patient_name=self.patient_name)
self.std_basal = self.pump.get_basal()
self.memory = Memory(self.args, device)
self.episode_history = np.zeros(combined_shape(self.max_epi_length, 14), dtype=np.float32)
self.reinit_flag = False
self.init_env()
self.log1_columns = ['epi', 't', 'cgm', 'meal', 'ins', 'rew', 'rl_ins', 'mu', 'sigma',
'prob', 'state_val', 'day_hour', 'day_min', 'IS']
self.log2_columns = ['epi', 't', 'reward', 'normo', 'hypo', 'sev_hypo', 'hyper', 'lgbi',
'hgbi', 'ri', 'sev_hyper', 'aBGP_rmse', 'cBGP_rmse']
self.save_log([self.log1_columns], '/'+self.worker_mode+'/data/logs_worker_')
self.save_log([self.log2_columns], '/'+self.worker_mode+'/data/'+self.worker_mode+'_episode_summary_')

def init_env(self):
if not self.reinit_flag:
self.episode += 1
self.counter = 0
self.init_state = self.env.reset()
self.cur_state, self.feat = self.state_space.update(cgm=self.init_state.CGM, ins=0, meal=0)
self.cgm_hist = deque(self.calibration * [0], self.calibration)
self.ins_hist = deque(self.calibration * [0], self.calibration)
self.cgm_hist.append(self.init_state.CGM)
self.pump.calibrate(self.init_state)
self.calibration_process()

def calibration_process(self):
self.reinit_flag, cur_cgm = False, 0
for t in range(0, self.calibration): # open-loop simulation for calibration period.
state, reward, is_done, info = self.env.step(self.std_basal)
cur_cgm = state.CGM
self.cgm_hist.append(state.CGM)
self.ins_hist.append(self.std_basal)
self.cur_state, self.feat = self.state_space.update(cgm=state.CGM, ins=self.std_basal,
meal=info['remaining_time'], hour=self.counter,
meal_type=info['meal_type']) #info['day_hour']
            self.reinit_flag = info['meal_type'] != 0  # meal_type zero -> no meal
        if (cur_cgm < 110 or 130 < cur_cgm) and self.worker_mode != 'training':  # require test simulations to start within normoglycaemia
self.reinit_flag = True
if self.reinit_flag:
self.init_env()

def rollout(self, ddpg, replay_memory):
ri, alive_steps, normo, hypo, sev_hypo, hyper, lgbi, hgbi, sev_hyper = 0, 0, 0, 0, 0, 0, 0, 0, 0
if self.worker_mode != 'training': # fresh env for testing
self.init_env()
rollout_steps = self.update_timestep if self.worker_mode == 'training' else self.max_test_epi_len

for n_steps in range(0, rollout_steps):
policy_step, mu, sigma = ddpg.get_action(self.cur_state, self.feat, worker_mode=self.worker_mode)
selected_action = policy_step[0]
rl_action, pump_action = self.pump.action(agent_action=selected_action, prev_state=self.init_state, prev_info=None)
state, reward, is_done, info = self.env.step(pump_action)
reward = composite_reward(self.args, state=state.CGM, reward=reward)
this_state = deepcopy(self.cur_state)
this_feat = deepcopy(self.feat)
done_flag = 1 if state.CGM <= 40 or state.CGM >= 600 else 0
# update -> state.
self.cur_state, self.feat = self.state_space.update(cgm=state.CGM, ins=pump_action,
meal=info['remaining_time'], hour=(self.counter+1),
meal_type=info['meal_type'], carbs=info['future_carb'])

if self.worker_mode == 'training':
replay_memory.push(torch.as_tensor(this_state, dtype=torch.float32, device=self.device).unsqueeze(0),
torch.as_tensor(this_feat, dtype=torch.float32, device=self.device).unsqueeze(0),
torch.as_tensor([selected_action], dtype=torch.float32, device=self.device),
torch.as_tensor([reward], dtype=torch.float32, device=self.device),
torch.as_tensor(self.cur_state, dtype=torch.float32, device=self.device).unsqueeze(0),
torch.as_tensor(self.feat, dtype=torch.float32, device=self.device).unsqueeze(0),
torch.as_tensor([done_flag], dtype=torch.float32, device=self.device))


# store -> rollout for training
# if self.worker_mode == 'training':
# self.memory.store(this_state, this_feat, selected_action, reward, self.cur_state, self.feat, done_flag)

self.episode_history[self.counter] = [self.episode, self.counter, state.CGM, info['meal'] * info['sample_time'],
pump_action, reward, rl_action, mu[0], sigma[0], 0, 0, info['day_hour'],
info['day_min'], 0]
self.counter += 1
stop_factor = (self.max_epi_length - 1) if self.worker_mode == 'training' else (self.max_test_epi_len - 1)
            criteria = state.CGM <= 40 or state.CGM >= 600 or self.counter > stop_factor  # training alternative: state.CGM >= 400
if criteria: # episode termination criteria.
df = pd.DataFrame(self.episode_history[0:self.counter], columns=self.log1_columns)
df.to_csv(self.args.experiment_dir + '/' + self.worker_mode + '/data/logs_worker_' + str(self.worker_id) + '.csv',
mode='a', header=False, index=False)
alive_steps = self.counter
normo, hypo, sev_hypo, hyper, lgbi, hgbi, ri, sev_hyper = time_in_range(df['cgm'], df['meal'], df['ins'],
self.episode, self.counter, display=False)
self.save_log([[self.episode, self.counter, df['rew'].sum(), normo, hypo, sev_hypo, hyper, lgbi,
hgbi, ri, sev_hyper, 0, 0]],
'/' + self.worker_mode + '/data/' + self.worker_mode + '_episode_summary_')

if self.worker_mode == 'training':
self.init_env()
else:
break # stop rollout if this is a testing worker!
data = [ri, alive_steps, normo, hypo, sev_hypo, hyper, lgbi, hgbi, sev_hyper]
return data

    def save_log(self, log_name, file_name):
        with open(self.args.experiment_dir + file_name + str(self.worker_id) + '.csv', 'a+') as f:
            csv.writer(f, delimiter=',').writerows(log_name)  # the with-block closes the file
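
A hedged sketch of how a driver loop might use Worker.rollout; the agent, replay memory, and patient lists are placeholders inferred from the constructor signature above, not the repository's actual training script:

# Hypothetical driver loop; every concrete value below is illustrative.
worker = Worker(args, mode='training', patients=patients, env_ids=env_ids,
                seed=0, worker_id=0, device=device)
tester = Worker(args, mode='testing', patients=patients, env_ids=env_ids,
                seed=1, worker_id=1, device=device)

for interaction in range(n_interactions):
    worker.rollout(ddpg_agent, replay_memory)   # collects args.n_step transitions
    # ... gradient updates from replay_memory would happen here ...
    metrics = tester.rollout(ddpg_agent, replay_memory)  # one capped test episode
    ri, alive_steps, normo, hypo, sev_hypo, hyper, lgbi, hgbi, sev_hyper = metrics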

