Skip to content

Commit

Permalink
Merge branch 'master' of /~https://github.com/scjjb/DRAS-MIL
Browse files Browse the repository at this point in the history
  • Loading branch information
scjjb committed Jan 26, 2023
2 parents a617153 + 85aa900 commit 471dbd1
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 3 deletions.
80 changes: 80 additions & 0 deletions bootstrapping_reps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
import pandas as pd
from sklearn.metrics import roc_auc_score
import numpy as np

def calculate_error(Y_hat, Y):
    """Return 0 when the prediction matches the label, otherwise 1.

    Args:
        Y_hat: predicted class label.
        Y: ground-truth class label.

    Returns:
        int: 0 for a correct prediction, 1 for an incorrect one.
    """
    return 0 if Y_hat == Y else 1


# Command-line interface. Evaluation results are read below from
# eval_results/EVAL_<model_name>[_run<r>]/fold_<f>.csv for each model.
parser = argparse.ArgumentParser(description='Model names input split by commas')
parser.add_argument('--model_names', type=str, default=None,help='models to plot')
parser.add_argument('--bootstraps', type=int, default=100000,
help='Number of bootstraps to calculate')
parser.add_argument('--run_repeats', type=int, default=10,
help='Number of model repeats')
parser.add_argument('--folds', type=int, default=10,
help='Number of cross-validation folds')
parser.add_argument('--num_classes',type=int,default=2,help='Number of classes')
args = parser.parse_args()
# NOTE(review): --model_names defaults to None, so omitting it raises
# AttributeError on .split here — consider required=True.
model_names=args.model_names.split(",")
bootstraps=args.bootstraps

# Bootstrap an AUC estimate for each model by resampling predictions
# across cross-validation folds and independent training repeats.
for model_name in model_names:
    model_name = 'eval_results/EVAL_' + model_name

    # Per-repeat collections: index r holds run r's predictions pooled
    # over all folds, in a fixed sample order.
    all_Ys = []      # ground-truth labels per run
    all_p1s = []     # P(class 1) per run (binary case)
    all_probs = []   # per-class probability DataFrame per run (multiclass case)
    all_Yhats = []   # predicted labels per run

    for run_no in range(args.run_repeats):
        Ys, p1s, Yhats = [], [], []
        probs = None  # DataFrame of per-class probabilities (multiclass only)
        for fold_no in range(args.folds):
            if args.run_repeats > 1:
                full_df = pd.read_csv(model_name + '_run{}/fold_{}.csv'.format(run_no, fold_no))
            else:
                full_df = pd.read_csv(model_name + '/fold_{}.csv'.format(fold_no))
            Ys = Ys + list(full_df['Y'])
            if args.num_classes == 2:
                p1s = p1s + list(full_df['p_1'])
            else:
                # The last num_classes columns hold the per-class probabilities.
                # Fix: accumulate this run's own frame with pd.concat — the
                # original tested len(all_probs) (the across-run list) and then
                # called list.append, whose None return clobbered `probs` on
                # every repeat after the first; DataFrame.append is also
                # removed in pandas >= 2.0.
                fold_probs = full_df.iloc[:, -args.num_classes:]
                probs = fold_probs if probs is None else pd.concat([probs, fold_probs], ignore_index=True)
            Yhats = Yhats + list(full_df['Y_hat'])
        all_Ys.append(Ys)
        all_probs.append(probs)
        all_p1s.append(p1s)
        all_Yhats.append(Yhats)

    AUC_scores = []
    for _ in range(bootstraps):
        # For every sample position, pick which repeat to take it from;
        # assumes every run evaluates the same samples in the same order
        # — TODO confirm fold CSVs preserve ordering across runs.
        idxs = np.random.choice(range(len(all_Ys)), len(all_Ys[0]))
        sample_Ys = []
        sample_probs = []
        sample_p1s = []
        for i, idx in enumerate(idxs):
            sample_Ys.append(all_Ys[idx][i])
            if args.num_classes > 2:
                # Fix: select row i positionally; plain [i] indexes DataFrame
                # columns by label, not rows.
                sample_probs.append(all_probs[idx].iloc[i])
            else:
                sample_p1s.append(all_p1s[idx][i])
        if args.num_classes > 2:
            AUC_scores.append(roc_auc_score(sample_Ys, sample_probs, multi_class='ovr'))
        else:
            AUC_scores.append(roc_auc_score(sample_Ys, sample_p1s))

    print("AUC mean: ", np.mean(AUC_scores), " AUC std: ", np.std(AUC_scores))




2 changes: 1 addition & 1 deletion other_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
for run_no in range(args.run_repeats):
for fold_no in range(args.folds):
if args.run_repeats>1:
full_df = pd.read_csv(model_name+'_run{}/fold_{}.csv'.format(run_no,fold_no+2))
full_df = pd.read_csv(model_name+'_run{}/fold_{}.csv'.format(run_no,fold_no))
else:
full_df = pd.read_csv(model_name+'/fold_{}.csv'.format(fold_no))
all_Ys=all_Ys+list(full_df['Y'])
Expand Down
73 changes: 73 additions & 0 deletions other_metrics_reps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score,balanced_accuracy_score
import numpy as np
import ast


# Command-line interface. Evaluation results are read below from
# eval_results/EVAL_<model_name>[_run<r>]/fold_<f>.csv for each model,
# and ground truths from dataset_csv/<data_csv>.
parser = argparse.ArgumentParser(description='Model names input split by commas')
parser.add_argument('--model_names', type=str, default=None,help='models to plot')
parser.add_argument('--bootstraps', type=int, default=100000,
help='Number of bootstraps to calculate')
parser.add_argument('--run_repeats', type=int, default=10,
help='Number of model repeats')
parser.add_argument('--folds', type=int, default=10,
help='Number of cross-validation folds')
parser.add_argument('--data_csv', type=str, default='set_all_714.csv')
parser.add_argument('--label_dict',type=str,default="{'high_grade':0,'low_grade':1,'clear_cell':2,'endometrioid':3,'mucinous':4}")
parser.add_argument('--num_classes',type=int,default=2)
args = parser.parse_args()
# NOTE(review): --model_names defaults to None, so omitting it raises
# AttributeError on .split here — consider required=True.
model_names=args.model_names.split(",")
bootstraps=args.bootstraps
# Parse the class-name -> index mapping; ast.literal_eval accepts only
# Python literals, so arbitrary code in --label_dict cannot execute.
label_dict=ast.literal_eval(args.label_dict)

# Bootstrap F1 / accuracy / balanced-accuracy estimates for each model by
# resampling predictions across cross-validation folds and model repeats.
for model_name in model_names:
    model_name = 'eval_results/EVAL_' + model_name

    all_Ys = []     # per-repeat ground-truth labels, pooled over folds
    all_p1s = []    # per-repeat P(class 1)
    all_Yhats = []  # per-repeat predicted labels
    #all_slides=[]
    all_ground_truths = []
    ground_truths = pd.read_csv("dataset_csv/{}".format(args.data_csv))

    for run_no in range(args.run_repeats):
        Ys = []
        p1s = []
        Yhats = []
        for fold_no in range(args.folds):
            if args.run_repeats > 1:
                fold_path = model_name + '_run{}/fold_{}.csv'.format(run_no, fold_no)
            else:
                fold_path = model_name + '/fold_{}.csv'.format(fold_no)
            full_df = pd.read_csv(fold_path)
            Ys.extend(full_df['Y'])
            p1s.extend(full_df['p_1'])
            Yhats.extend(full_df['Y_hat'])
            #all_slides=all_slides+list(full_df['slide_id'])
        all_Ys.append(Ys)
        all_p1s.append(p1s)
        all_Yhats.append(Yhats)

    f1s = []
    accuracies = []
    balanced_accuracies = []
    for _ in range(bootstraps):
        # For each sample position, pick which repeat to draw it from.
        idxs = np.random.choice(range(len(all_Ys)), len(all_Ys[0]))
        sample_Ys = [all_Ys[idx][i] for i, idx in enumerate(idxs)]
        sample_p1s = [all_p1s[idx][i] for i, idx in enumerate(idxs)]
        sample_Yhats = [all_Yhats[idx][i] for i, idx in enumerate(idxs)]
        if args.num_classes == 2:
            f1s.append(f1_score(sample_Ys, sample_Yhats))
        else:
            f1s.append(f1_score(sample_Ys, sample_Yhats, average='macro'))
        accuracies.append(accuracy_score(sample_Ys, sample_Yhats))
        balanced_accuracies.append(balanced_accuracy_score(sample_Ys, sample_Yhats))

    if args.num_classes == 2:
        print("F1 mean: ", np.mean(f1s), " F1 std: ", np.std(f1s))
    else:
        print("Macro F1 mean: ", np.mean(f1s), " F1 std: ", np.std(f1s))
    print("accuracy mean: ", np.mean(accuracies), " accuracy std: ", np.std(accuracies))
    print("balanced accuracy mean: ", np.mean(balanced_accuracies), " balanced accuracy std: ", np.std(balanced_accuracies))
4 changes: 2 additions & 2 deletions utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def eval(config, dataset, args, ckpt_path):
if args.tuning:
args.weight_smoothing=config["weight_smoothing"]
args.resampling_iterations=config["resampling_iterations"]
args.samples_per_iteration=int(1200/(config["resampling_iterations"]))
args.samples_per_iteration=int(640/(config["resampling_iterations"]))
args.sampling_neighbors=config["sampling_neighbors"]
args.sampling_random=config["sampling_random"]
args.sampling_random_delta=config["sampling_random_delta"]
Expand Down Expand Up @@ -474,7 +474,7 @@ def summary_sampling(model, dataset, args):

#all_probs[(batch_idx*same_slide_repeats)+repeat_no] = probs
all_probs.append(probs[0])
all_labels_byrep.append(label)
all_labels_byrep.append(label[0].item())
all_preds[(batch_idx*same_slide_repeats)+repeat_no] = Y_hat.item()

if args.plot_sampling:
Expand Down

0 comments on commit 471dbd1

Please sign in to comment.