diff --git a/run_data_measurements.py b/run_data_measurements.py index 7724553..07fb75a 100644 --- a/run_data_measurements.py +++ b/run_data_measurements.py @@ -59,7 +59,6 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): # TODO: Catch error exceptions for each measurement, so that an error # for one measurement doesn't break the calculation of all of them. - do_all = False dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args, use_cache=use_cache) logs.info("Tokenizing dataset.") @@ -67,17 +66,14 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): logs.info("Calculating vocab.") dstats.load_or_prepare_vocab() - if not calculation: - do_all = True - - if do_all or calculation == "general": + if calculation == "all" or calculation == "general": logs.info("\n* Calculating general statistics.") dstats.load_or_prepare_general_stats() logs.info("Done!") logs.info( "Basic text statistics now available at %s." % dstats.general_stats_json_fid) - if do_all or calculation == "duplicates": + if calculation == "all" or calculation == "duplicates": logs.info("\n* Calculating text duplicates.") dstats.load_or_prepare_text_duplicates() duplicates_fid_dict = dstats.duplicates_files @@ -85,7 +81,7 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): for key, value in duplicates_fid_dict.items(): logs.info("%s: %s" % (key, value)) - if do_all or calculation == "lengths": + if calculation == "all" or calculation == "lengths": logs.info("\n* Calculating text lengths.") dstats.load_or_prepare_text_lengths() length_fid_dict = dstats.length_obj.get_filenames() @@ -94,7 +90,7 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): print("%s: %s" % (key, value)) print() - if do_all or calculation == "labels": + if calculation == "all" or calculation == "labels": logs.info("\n* Calculating label statistics.") dstats.load_or_prepare_labels() npmi_fid_dict = dstats.label_files @@ -103,7 +99,20 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): print("%s: %s" % (key, value)) print() - if do_all or calculation == "npmi": + + if calculation == "all" or calculation == "zipf": + logs.info("\n* Preparing Zipf.") + dstats.load_or_prepare_zipf() + logs.info("Done!") + zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids( + dstats.dataset_cache_dir) + logs.info("Zipf results now available at %s." % zipf_json_fid) + logs.info( + "Figure saved to %s, with corresponding json at %s." + % (zipf_fig_html_fid, zipf_fig_json_fid) + ) + + if calculation == "all" or calculation == "npmi": print("\n* Preparing nPMI.") dstats.load_or_prepare_npmi() npmi_fid_dict = dstats.npmi_files @@ -117,24 +126,12 @@ def load_or_prepare(dataset_args, calculation=False, use_cache=False): print("%s: %s" % (key, value)) print() - if do_all or calculation == "zipf": - logs.info("\n* Preparing Zipf.") - dstats.load_or_prepare_zipf() - logs.info("Done!") - zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids( - dstats.dataset_cache_dir) - logs.info("Zipf results now available at %s." % zipf_json_fid) - logs.info( - "Figure saved to %s, with corresponding json at %s." - % (zipf_fig_html_fid, zipf_fig_json_fid) - ) - - # Don't do this one until someone specifically asks for it -- takes awhile. + # We removed this from the tool. if calculation == "embeddings": logs.info("\n* Preparing text embeddings.") dstats.load_or_prepare_embeddings() - # Don't do this one until someone specifically asks for it -- takes awhile. + # We removed this from the tool. if calculation == "perplexities": logs.info("\n* Preparing text perplexities.") dstats.load_or_prepare_text_perplexities() @@ -210,6 +207,7 @@ def main(): parser.add_argument( "-w", "--calculation", + default="all", help="""What to calculate (defaults to everything except embeddings and perplexities).\n Options are:\n