Skip to content

Commit

Permalink
Enable shap for real runs, disable shap for permutation runs, update …
Browse files Browse the repository at this point in the history
…README
  • Loading branch information
yiming-kang committed Jun 18, 2022
1 parent f1e7908 commit 6386529
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
12 changes: 6 additions & 6 deletions CODE/explain_human_resps.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,11 @@ def parse_args(argv):
parser.add_argument(
'-N', '--number_of_permutations', type=int, nargs='?', const=5, choices=range(1,100),
help="Number of permutation runs (default 5).")
parser.add_argument(
'--enable_permutation_shap', action='store_true',
help="Enable SHAP for permutation runs. Permutations will default to not calculating SHAP values\n(using this flag will make permutations training significantly longer.")
parsed = parser.parse_args(argv[1:])
return parsed



def run_tfpr(tf_feat_mtx_dict, nontf_feat_mtx, features, label_df_dict, output_dir, model_hyparams, model_tuning, permutations, number_of_permutations, disable_shap):
def run_tfpr(tf_feat_mtx_dict, nontf_feat_mtx, features, label_df_dict, output_dir, model_hyparams, model_tuning, permutations, number_of_permutations):

last_run_num = 0

Expand Down Expand Up @@ -83,6 +79,10 @@ def run_tfpr(tf_feat_mtx_dict, nontf_feat_mtx, features, label_df_dict, output_d
logger.info('==> Cross validating response prediction model <==')
tfpr_explainer.cross_validate(permute=permutations)

if not permutations:
logger.info('==> Analyzing feature contributions <==')
tfpr_explainer.explain()

logger.info('==> Saving output data <==')
tfpr_explainer.save()

Expand Down Expand Up @@ -126,7 +126,7 @@ def main(argv):
tf_feat_mtx_dict[feat_info_dict['tfs'][0]].shape,
nontf_feat_mtx.shape))

run_tfpr(tf_feat_mtx_dict, nontf_feat_mtx, features, label_df_dict, filepath_dict['output_dir'], model_hyparams, args.model_tuning, args.permutations, args.number_of_permutations, args.disable_shap)
run_tfpr(tf_feat_mtx_dict, nontf_feat_mtx, features, label_df_dict, filepath_dict['output_dir'], model_hyparams, args.model_tuning, args.permutations, args.number_of_permutations)


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion CODE/explain_yeast_resps.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def main(argv):

logger.info('==> Cross validating response prediction model <==')
tfpr_explainer.cross_validate()


logger.info('==> Analyzing feature contributions <==')
tfpr_explainer.explain()

logger.info('==> Saving output data <==')
tfpr_explainer.save()

Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ $ python3 CODE/preprocess_yeast_data.py \
--gene_var <pre-pert gex variations (.csv)>
```

For human data, use `CODE/preprocess_human_data.py` with the above set of arguments with the additional `-r` for distal enhancer and promoter data.
For human data, use `CODE/preprocess_human_data.py` with the above set of arguments with the additional `-r <regulatory elements (.bed)>` for distal enhancer and promoter data.

For running random permutation on human data, enable it by setting `--permutations` (default to 5 runs) and optionally set the number of permutations by `--number_of_permutations <n>`.

### Response label

Expand Down

0 comments on commit 6386529

Please sign in to comment.