Skip to content

Commit

Permalink
tune v3
Browse files Browse the repository at this point in the history
  • Loading branch information
JunangWang committed Mar 19, 2024
1 parent 6822a7a commit cece0f9
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 20 deletions.
61 changes: 47 additions & 14 deletions Modeling eMNS/Generative_model_ETH_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
"### Train ETH data to CNN generative network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -U \"ray[data,train,tune,serve]\""
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -90,12 +99,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_44369/275140010.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 128\u001b[0m )\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/tuner.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_ray_client\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_local_tuner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTuneError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m raise TuneError(\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 507\u001b[0m \u001b[0mparam_space\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_space\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 508\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_restored\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 509\u001b[0;31m \u001b[0manalysis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_space\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0manalysis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_resume\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_space\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py\u001b[0m in \u001b[0;36m_fit_internal\u001b[0;34m(self, trainable, param_space)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tuner_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m }\n\u001b[0;32m--> 628\u001b[0;31m analysis = run(\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 630\u001b[0m )\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/tune.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(run_or_experiment, name, metric, mode, stop, time_budget_s, config, resources_per_trial, num_samples, storage_path, storage_filesystem, search_alg, scheduler, checkpoint_config, verbose, progress_reporter, log_to_file, trial_name_creator, trial_dirname_creator, sync_config, export_formats, max_failures, fail_fast, restore, resume, reuse_actors, raise_on_failed_trial, callbacks, max_concurrent_trials, keep_checkpoints_num, checkpoint_score_attr, checkpoint_freq, checkpoint_at_end, chdir_to_trial_dir, local_dir, _remote, _remote_string_queue, _entrypoint)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1015\u001b[0;31m \u001b[0mrunner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mforce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Trial Runner checkpointing failed: {str(e)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py\u001b[0m in \u001b[0;36mcheckpoint\u001b[0;34m(self, force, wait)\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_checkpoint_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_checkpoint_enabled\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mforce\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 477\u001b[0m ):\n\u001b[0;32m--> 478\u001b[0;31m self._checkpoint_manager.checkpoint(\n\u001b[0m\u001b[1;32m 479\u001b[0m \u001b[0msave_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_to_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mforce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 480\u001b[0m )\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/execution/experiment_state.py\u001b[0m in \u001b[0;36mcheckpoint\u001b[0;34m(self, save_fn, force, wait)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;31m# This context will serialize the dataset execution plan instead, if available.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mout_of_band_serialize_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0msave_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# Sync to cloud\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py\u001b[0m in \u001b[0;36msave_to_dir\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp_file_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"w\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 370\u001b[0;31m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrunner_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTuneFunctionEncoder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 371\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 372\u001b[0m os.replace(\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/json/__init__.py\u001b[0m in \u001b[0;36mdump\u001b[0;34m(obj, fp, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;31m# could accelerate with writelines in some versions of Python, at\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;31m# a debuggability cost\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode\u001b[0;34m(o, _current_indent_level)\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m_iterencode_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 431\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m_iterencode_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 432\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmarkers\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode_dict\u001b[0;34m(dct, _current_indent_level)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 404\u001b[0m \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 405\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mchunks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 406\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnewline_indent\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 407\u001b[0m \u001b[0m_current_indent_level\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode_list\u001b[0;34m(lst, _current_indent_level)\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 325\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mchunks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 326\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnewline_indent\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0m_current_indent_level\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/vipuser/anaconda3/lib/python3.9/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode_list\u001b[0;34m(lst, _current_indent_level)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseparator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0m_encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 303\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'null'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"from Neural_network import Generative_net, Generative_net_test, ResidualEMNSBlock_3d, BigBlock, weight_init, eMNS_Dataset\n",
"from Training_loop_v2 import train_GM\n",
"from Training_loop_v2 import train_GM, train_GM_ray\n",
"from functools import partial\n",
"from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n",
"from ray.train.torch import TorchTrainer\n",
Expand Down Expand Up @@ -124,9 +157,9 @@
" )\n",
"param_space = {\n",
" \"scaling_config\": ScalingConfig(\n",
" num_workers = 1,\n",
" num_workers = 3,\n",
" use_gpu = use_gpu,\n",
" resources_per_worker = {\"CPU\":10, \"GPU\":2}\n",
" # resources_per_worker = {\"CPU\":4, \"GPU\":1}\n",
" ),\n",
" # You can even grid search various datasets in Tune.\n",
" # \"datasets\": {\n",
Expand All @@ -135,13 +168,13 @@
" # ),\n",
" # },\n",
" \"train_loop_config\": {\n",
" 'epochs': tune.choice([350]),\n",
" 'lr_max': tune.loguniform(1e-4,1e-2),\n",
" 'lr_min': tune.loguniform(1e-5,1e-7),\n",
" 'batch_size': tune.choice([4,8,16]),\n",
" 'L2_norm' : tune.choice([0]),\n",
" 'epochs': 350,\n",
" 'lr_max': 1e-4,\n",
" 'lr_min': 2.5e-6,\n",
" 'batch_size': 8,\n",
" 'L2_norm' : 0,\n",
" 'verbose': False,\n",
" 'DF' : tune.choice([True,False]),\n",
" 'DF' : False,\n",
" 'schedule': [],\n",
" 'grid_space': 16**3,\n",
" 'learning_rate_decay': 0.5,\n",
Expand Down Expand Up @@ -251,7 +284,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot_ray_results(results, metrics_names=['rmse_train','rmse_val'],ylim=[20,50])"
"plot_ray_results(results, metrics_names=['rmse_train','rmse_val'],ylim=[3,50])"
]
},
{
Expand Down Expand Up @@ -280,13 +313,13 @@
")\n",
"\n",
"config = {\n",
" 'epochs': 10,\n",
" 'epochs': 350,\n",
" 'lr_max': 1e-4,\n",
" 'lr_min': 2.5e-6,\n",
" 'batch_size': 8,\n",
" 'L2_norm' : 0,\n",
" 'verbose': False,\n",
" 'DF' : True,\n",
" 'DF' : False,\n",
" 'schedule': [],\n",
" 'grid_space': 16**3,\n",
" 'learning_rate_decay': 0.5,\n",
Expand Down
Loading

0 comments on commit cece0f9

Please sign in to comment.