diff --git a/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..6fa661e2 --- /dev/null +++ b/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DenoiSeg_2D_ZeroCostDL4Micv2.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"IkSguVy8Xv83"},"source":["# **Image denoising and segmentation using DenoiSeg 2D**\n","\n","---\n","\n"," DenoiSeg 2D is deep-learning method that can be used to jointly denoise and segment 2D microscopy images. By running this notebook, you can train your and use you own network. \n","\n"," The benefits of using DenoiSeg (compared to other Deep Learning-based segmentation methods) are more prononced when only a few annotated images are available. However, the denoising part requires many images to perform well. All the noisy images don't need to be labeled to train DenoiSeg.\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper: **DenoiSeg: Joint Denoising and Segmentation**\n","Tim-Oliver Buchholz, Mangal Prakash, Alexander Krull, Florian Jug\n","https://arxiv.org/abs/2005.02987\n","\n","And source code found in: /~https://github.com/juglab/DenoiSeg/wiki\n","\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","**it needs to have access to a paired training dataset made of images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**Importantly, the benefits of using DenoiSeg are more pronounced when only limited numbers of segmentation annotations are available for training. However, DenoiSeg also expects that lots of noisy raw images are available to train the denoising part. It is therefore not required for all the noisy images to be annotated to train DenoiSeg**.\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Noisy Images (Training_source)\n"," - img_1.tif, img_2.tif, img_3.tif, img_4.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif\n"," - **Quality control dataset (optional, not required for training)**\n"," - Noisy Images\n"," - img_1.tif, img_2.tif\n"," - High SNR Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"h5i5CS2bSmZr","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"01Djr8v-5pPk","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n4yWFoJNnoin"},"source":["# **2. Install DenoiSeg and Dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fq21zJVFNASx","colab":{}},"source":["#@markdown ##Install DenoiSeg and dependencies\n","!pip install q keras==2.2.5\n","\n","# Here we enable Tensorflow 1. \n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install denoiseg\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to Denoiseg -------\n","\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","import numpy as np\n","from matplotlib import pyplot as plt\n","from scipy import ndimage\n","\n","from denoiseg.models import DenoiSeg, DenoiSegConfig\n","from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data\n","from denoiseg.utils.seg_utils import *\n","from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels\n","\n","from csbdeep.utils import plot_history\n","from tifffile import imread, imsave\n","from glob import glob\n","\n","import urllib\n","import os\n","import zipfile\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`Priority`:** Choose how much relative the importance to assign to the denoising \n","and segmentation tasks by choosing an appropriate value (between 0 and 1; with 0 being only segmentation and 1 being only denoising. **Default value: 0.5**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: depends on number of patches, min 100; max 400**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"ewpNJ_I0Mv47","colab":{}},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Path to validation image(s): \n","#Validation_source = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Test - Noisy\" #@param {type:\"string\"}\n","#Validation_target = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Test - Masks\" #@param {type:\"string\"}\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 10#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","Priority = 0.5#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","batch_size = 128#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," Priority = 0.5\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(R + \"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_target))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","# Here we count the number of files in the training target folder\n","Mask_Filelist = os.listdir(Training_target)\n","Mask_number_files = len(Mask_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Mask_for_validation = int((Mask_number_files)/percentage_validation)\n","\n","if Mask_for_validation == 0:\n"," Mask_for_validation = 2\n","if Mask_for_validation == 1:\n"," Mask_for_validation = 2\n","\n","# Here we count the number of files in the training target folder\n","Noisy_Filelist = os.listdir(Training_source)\n","Noisy_number_files = len(Noisy_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n","\n","if Noisy_for_validation == 0:\n"," Noisy_for_validation = 1\n","\n","#Here we find the noisy images that do not have masks\n","noisy_image_no_mask_list = list(set(Noisy_Filelist) - set(Mask_Filelist))\n","\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","#Here we move images to be used for validation\n","for i in range(Mask_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+list_target[i], Validation_source_temp+\"/\"+list_target[i])\n"," shutil.move(Training_target_temp+\"/\"+list_target[i], Validation_target_temp+\"/\"+list_target[i])\n","\n","#Here we move a few more noisy images for validation\n","if noisy_image_no_mask_list:\n"," for y in range(Noisy_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+noisy_image_no_mask_list[y], Validation_source_temp+\"/\"+noisy_image_no_mask_list[y])\n","\n","\n","print(\"Parameters initiated.\")\n","\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.title('Training target')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis (multiply the dataset by 8). \n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," \n","\n","\n"," "]},{"cell_type":"code","metadata":{"id":"VipPCXmwL1YN","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a DenoiSeg model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Results/test_denoiSeg_3\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"rBelu-LtbOTh","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","print(\"In progress...\")\n","\n","Training_source_dir = Training_source_temp\n","Training_target_dir = Training_target_temp\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","validation_images_tiff=Validation_source_temp+\"/*.tif\"\n","validation_mask_tiff=Validation_target_temp+\"/*.tif\"\n","\n","train_images = imread(sorted(glob(training_images_tiff)))\n","val_images = imread(sorted(glob(validation_images_tiff)))\n","\n","available_train_masks = imread(sorted(glob(mask_images_tiff)))\n","available_val_masks = imread(sorted(glob(validation_mask_tiff)))\n","\n","#This allows the users to not have all their training images segmented\n","blank_images_train = np.zeros((train_images.shape[0]-available_train_masks.shape[0], available_train_masks.shape[1], available_train_masks.shape[2]))\n","blank_images_val = np.zeros((val_images.shape[0]-available_val_masks.shape[0], available_val_masks.shape[1], available_val_masks.shape[2]))\n","blank_images_train = blank_images_train.astype(\"uint16\")\n","blank_images_val = blank_images_val.astype(\"uint16\")\n","\n","train_masks = np.concatenate((available_train_masks,blank_images_train), axis = 0)\n","val_masks = np.concatenate((available_val_masks,blank_images_val), axis = 0)\n","\n","\n","if not Use_Data_augmentation:\n"," X, Y_train_masks = train_images, train_masks\n","\n","# Now we apply data augmentation to the training patches:\n","# Rotate four times by 90 degree and add flipped versions.\n","if Use_Data_augmentation:\n"," X, Y_train_masks = augment_data(train_images, train_masks)\n","\n","X_val, Y_val_masks = val_images, val_masks\n","\n","# Here we add the channel dimension to our input images.\n","# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)\n","X = X[...,np.newaxis]\n","Y = convert_to_oneHot(Y_train_masks)\n","X_val = X_val[...,np.newaxis]\n","Y_val = convert_to_oneHot(Y_val_masks)\n","print(\"Shape of X: {}\".format(X.shape))\n","print(\"Shape of Y: {}\".format(Y.shape))\n","print(\"Shape of X_val: {}\".format(X_val.shape))\n","print(\"Shape of Y_val: {}\".format(Y_val.shape))\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= max(100, min(int(X.shape[0]/batch_size), 400))\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","\n","config = DenoiSegConfig(X, unet_kern_size=3, n_channel_out=4, relative_weights = [1.0,1.0,5.0],\n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," batch_norm=True, train_batch_size=batch_size, unet_n_first = 32, \n"," unet_n_depth=4, denoiseg_alpha=Priority, train_learning_rate = initial_learning_rate, train_tensorboard=False)\n","\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","\n","model = DenoiSeg(config=config, name=model_name, basedir=model_path)\n","\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (DenoiSeg -- DenoiSeg Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fisJmA13Mv5e","scrolled":true,"colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start Training\n","%memit\n","\n","\n","\n","history = model.train(X, Y, (X_val, Y_val))\n","\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), val_masks, measure=measure_precision())\n","\n","print(\"The higest score of {} is achieved with threshold = {}.\".format(np.round(val_score, 3), threshold))\n","\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate','threshold'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i], str(threshold)])\n","\n","#Thresholdpath = model_path+'/'+model_name+'/Quality Control/optimal_threshold.csv'\n","#with open(Thresholdpath, 'w') as f1:\n"," #writer1 = csv.writer(f1)\n"," #writer1.writerow(['threshold'])\n"," #writer1.writerow([str(threshold)])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='DenoiSeg', \n"," description='DenoiSeg 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"vMzSP50kMv5p","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","**DenoiSeg** allow to both denoise and segment microscopy images. This section allow you to evaluate both tasks separetly.\n","\n","**Evaluation of the denoising**\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_Denoising_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","**Evaluation of the Segmentation**\n","\n","This option will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_Segmentation_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose what to evaluate\n","\n","Evaluate_Denoising = True #@param {type:\"boolean\"}\n","\n","Evaluate_Segmentation = True #@param {type:\"boolean\"}\n","\n","\n","# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Test - Noisy\" #@param{type:\"string\"}\n","Target_Denoising_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Test - GT Images\" #@param{type:\"string\"}\n","Target_Segmentation_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/DenoiSeg/Test - Masks\" #@param{type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","#Load the threshold value. \n","\n","if os.path.exists(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# ------------- Prepare the model and run predictions ------------\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Source_QC_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_segmentation_\"+base_filename, segmented_images)\n","\n","# ------------- Here we Start assessing the denoising against GT ------------\n","\n","if Evaluate_Denoising:\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_Denoising_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_denoised_\"+i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n"," Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n"," norm = simple_norm(x, percent = 99)\n","\n"," plt.figure(figsize=(15,15))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(Target_Denoising_folder, Test_FileList[-1]))\n"," plt.imshow(img_GT, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n"," plt.imshow(img_Source, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", \"Predicted_denoised_\"+Test_FileList[-1]))\n"," plt.imshow(img_Prediction, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","#________________________________________________________________________\n","# Here we start testing the differences between GT and predicted masks\n","\n","if Evaluate_Segmentation:\n","\n","\n","\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control_Segmentation for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_segmentation_\"+n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_Segmentation_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","#Display the last image\n","\n"," f = plt.figure(figsize=(25,25))\n","\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(test_input, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n","\n","#Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_0_to_255, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n","#Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_prediction_0_to_255, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n","#Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_0_to_255, cmap='Greens')\n"," plt.imshow(test_prediction_0_to_255, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"bl3EdYFVS7X9","colab":{}},"source":["import imageio\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/Stardist/Test - Images\" #@param {type:\"string\"}\n","Result_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/Results\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"/content/gdrive/My Drive/Work/manuscript/Ongoing Projects/Zero-Cost Deep-Learning to Enhance Microscopy/test folder/Training datasets/Results/test_denoiSeg_2\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a Threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","#Load the threshold value. \n","\n","if os.path.exists(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","print(\"Processing...\")\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(Result_folder+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(Result_folder+\"/\"+\"Predicted_segmentation_\"+base_filename,segmented_images)\n"," \n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"jFp-0y4zT_gL","colab":{}},"source":["\n","# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+\"Predicted_denoised_\"+random_choice)\n","z = imread(Result_folder+\"/\"+\"Predicted_segmentation_\"+random_choice)\n","\n","norm = simple_norm(x, percent = 99)\n","\n","plt.figure(figsize=(30,15))\n","plt.subplot(1, 4, 1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Input\")\n","\n","plt.subplot(1, 4, 2)\n","plt.imshow(y, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Predicted denoised image\")\n","\n","plt.subplot(1, 4, 3)\n","plt.imshow(z, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.axis('off');\n","plt.title(\"Predicted segmentation\")\n","\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"wgO7Ok1PBFQj"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS","colab_type":"text"},"source":["#**Thank you for using DenoiSeg!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index 347b0ae4..9955eeaf --- a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","#Libraries contains information of certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-CEUqlS8o3M","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"qe9zvEJ9qOH2","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"zmtlu9YU266X","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 1 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4kb3xSZMRzxU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"mlN-VNOgR-nr","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"LKYRNhA5Qnis","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"biXiR017C4UU","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","#Libraries contains information of certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-CEUqlS8o3M","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"qe9zvEJ9qOH2","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"zmtlu9YU266X","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 1 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4kb3xSZMRzxU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"mlN-VNOgR-nr","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"LKYRNhA5Qnis","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"biXiR017C4UU","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index a7ee7342..58600e39 --- a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 40#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","colab_type":"code","cellView":"form","colab":{}},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. \n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Result_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J","colab_type":"text"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","colab_type":"code","cellView":"form","colab":{}},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. \n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Result_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J","colab_type":"text"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/ChangeLog.txt b/Colab_notebooks/ChangeLog.txt old mode 100755 new mode 100644 index 47f0b572..b128bf1d --- a/Colab_notebooks/ChangeLog.txt +++ b/Colab_notebooks/ChangeLog.txt @@ -7,6 +7,28 @@ Latest releases available here: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/releases +————————————————————————————————————————————————————————— +ZeroCostDL4Mic v1.10 + + +Major changes: + +- New beta notebook : DenoiSeg 2D + +- StarDist 2D, StarDist 3D, CARE 2D and CARE 3D notebooks now back to TensorFlow 1.5 (instead of TF 2.2, issues with TF 2.3) + +- Deep-STORM now runs on TensorFlow 2.3 + +- Models trained using StarDist 2D, CARE 2D, CARE 3D, DenoiSeg 2D, Noise2Void 2D and 3D notebooks can be used in Fiji via their respective plugin + + +————————————————————————————————————————————————————————— +ZeroCostDL4Mic v1.9 + +Minor aesthetic bug fixes (titles and section naming mostly). + + + ————————————————————————————————————————————————————————— ZeroCostDL4Mic v1.8 diff --git a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 diff --git a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index cb0b72d1..5986637d --- a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4","colab_type":"text"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.__version__ != '2.2.0':\n"," !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ","colab_type":"text"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Deep-STORM and dependencies\n","\n","# %% Model definition + helper functions\n","\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). \n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. - the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," # Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos","colab_type":"text"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l","colab_type":"text"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). "]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Load raw data\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9","colab_type":"text"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. \n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. `FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necesary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY","colab_type":"text"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decreae the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many pacthes are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balancing the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be autimatically calculated using an empiraical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx","colab_type":"text"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. **Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(WARNING+'!! WARNING: No patches were found in memory currently. !!')\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n","colab_type":"text"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA","colab_type":"text"},"source":["## **4.4. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH","colab_type":"text"},"source":["##**4.5. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, \"QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n","\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This paramter determines threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in less localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This paramter determines size of the neighborhood within which the prediction needs to be a local maxima in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This paramter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display -------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ","colab_type":"text"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estmication (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bins are used ot build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bins needs to be interpolated to every single frames. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automaticaly saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r","colab_type":"text"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n","colab_type":"text"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV","colab_type":"text"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4","colab_type":"text"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","# if tf.__version__ != '2.2.0':\n","# !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ","colab_type":"text"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Deep-STORM and dependencies\n","\n","# %% Model definition + helper functions\n","\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). \n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. - the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," # Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos","colab_type":"text"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l","colab_type":"text"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). "]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Load raw data\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9","colab_type":"text"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. \n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. `FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necesary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY","colab_type":"text"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decreae the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many pacthes are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balancing the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be autimatically calculated using an empiraical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx","colab_type":"text"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. **Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(WARNING+'!! WARNING: No patches were found in memory currently. !!')\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n","colab_type":"text"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA","colab_type":"text"},"source":["## **4.4. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH","colab_type":"text"},"source":["##**4.5. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, \"QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n","\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This paramter determines threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in less localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This paramter determines size of the neighborhood within which the prediction needs to be a local maxima in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This paramter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display -------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ","colab_type":"text"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estmication (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bins are used ot build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bins needs to be interpolated to every single frames. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automaticaly saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r","colab_type":"text"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n","colab_type":"text"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV","colab_type":"text"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/LICENSE.txt b/Colab_notebooks/LICENSE.txt old mode 100755 new mode 100644 diff --git a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb index 78b9874b..7361d58f 100644 --- a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"IkSguVy8Xv83"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"h5i5CS2bSmZr","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"01Djr8v-5pPk","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fq21zJVFNASx","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1.\n","!pip install q keras==2.2.5\n","\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"ewpNJ_I0Mv47","colab":{}},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(R + \"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"-Vy-vV7ssabS","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"rBelu-LtbOTh","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(X.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fisJmA13Mv5e","scrolled":true,"colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"vMzSP50kMv5p","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"bl3EdYFVS7X9","colab":{}},"source":["#Activate the pretrained model. \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","print(\"Saving the images...\")\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier.\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"jFp-0y4zT_gL","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output')\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"wgO7Ok1PBFQj"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS","colab_type":"text"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"IkSguVy8Xv83"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"h5i5CS2bSmZr","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"01Djr8v-5pPk","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fq21zJVFNASx","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1.\n","!pip install q keras==2.2.5\n","\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"ewpNJ_I0Mv47","colab":{}},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(R + \"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"-Vy-vV7ssabS","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"rBelu-LtbOTh","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(X.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fisJmA13Mv5e","scrolled":true,"colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"vMzSP50kMv5p","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"bl3EdYFVS7X9","colab":{}},"source":["#Activate the pretrained model. \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","print(\"Saving the images...\")\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier.\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"jFp-0y4zT_gL","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output')\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"wgO7Ok1PBFQj"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS","colab_type":"text"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb index e0849fd9..34383250 100644 --- a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83","colab_type":"text"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","!pip install q keras==2.2.5\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. Disable this option is you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M","colab_type":"text"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"scrolled":true,"colab_type":"code","cellView":"form","id":"iwNmp1PUzRDQ","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. \n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","colab_type":"code","cellView":"form","colab":{}},"source":["#Activate the pretrained model. \n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t","colab_type":"text"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83","colab_type":"text"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","!pip install q keras==2.2.5\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. Disable this option is you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M","colab_type":"text"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"scrolled":true,"colab_type":"code","cellView":"form","id":"iwNmp1PUzRDQ","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 3D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='ZYX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. \n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","colab_type":"code","cellView":"form","colab":{}},"source":["#Activate the pretrained model. \n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t","colab_type":"text"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/ReadMe.txt b/Colab_notebooks/ReadMe.txt old mode 100755 new mode 100644 diff --git a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb index a34f5fca..8dc6cd64 100644 --- a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WAfQW1Mj3wy1XQZZUfU4DJVS_R_E8Cn3","timestamp":1585665697353},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"/~https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","!pip install q keras==2.2.5\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"KWpu5p8utpE2","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2 #@param {type:\"number\"}\n","number_of_steps = 20#@param {type:\"number\"}\n","patch_size = 1024 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vgT0NU3P6Bwt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8in3wzAw6G6g","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"2zk1H8J06aJH","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"x4zMG4lMths-","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"SfQeukJJtv9u","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1) # normalize channels independently\n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. \n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["\n","## **4.2. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"o2O0QnO4PFlz","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-2b4RMU_Ec2y","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"KG8wZrA3Ef4n","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GFJBwr5TEgcq","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"EvCMiYaeElc4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","#Display the last image\n","\n","f = plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255, aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, cmap='Greens')\n","plt.imshow(test_prediction_0_to_255, alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iAPmwlxCEzxQ","colab_type":"text"},"source":["# **6. Using the trained model**\n","---"]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Stacks #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," \n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WAfQW1Mj3wy1XQZZUfU4DJVS_R_E8Cn3","timestamp":1585665697353},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"/~https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"KWpu5p8utpE2","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2 #@param {type:\"number\"}\n","number_of_steps = 20#@param {type:\"number\"}\n","patch_size = 1024 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vgT0NU3P6Bwt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8in3wzAw6G6g","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"2zk1H8J06aJH","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"x4zMG4lMths-","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"SfQeukJJtv9u","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1) # normalize channels independently\n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. \n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["\n","## **4.2. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"o2O0QnO4PFlz","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-2b4RMU_Ec2y","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"KG8wZrA3Ef4n","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GFJBwr5TEgcq","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"EvCMiYaeElc4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","#Display the last image\n","\n","f = plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255, aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, cmap='Greens')\n","plt.imshow(test_prediction_0_to_255, alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iAPmwlxCEzxQ","colab_type":"text"},"source":["# **6. Using the trained model**\n","---"]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Stacks #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," \n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb index d36ffe5d..cce718c2 100644 --- a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1Ur-4VIQ6gf4ONupD6hK0M-AcJkoTzMlU","timestamp":1586789439593},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"/~https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nAW3oU60htR_","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches use to train StarDist 3D (length of a side). The value should be smaller or equal to the dimensions of the image. Make patch size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbyf-RevQhDL","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UQ2hultWQlT9","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"id":"wYdTY6ULg01b","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###See Elasticdeform’s license\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\n","\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\n","\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","print(\"Double click to see elasticdeform’s license\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKLB47jgQrxr","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = True #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pjz-5bRVh1ja","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"zeSUtd2Thw-O","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"LqH54fYhdbXU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RzAHUsi-78Ak","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w3Z7Jkv8bPvq","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"05dbg6UrGunj","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mBkuXf5zhHUd","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"i9ek_kIHhK1R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["from PIL import Image\n","\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 2#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1Ur-4VIQ6gf4ONupD6hK0M-AcJkoTzMlU","timestamp":1586789439593},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"/~https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D dataset. If you are interested in 2D dataset, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nAW3oU60htR_","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches use to train StarDist 3D (length of a side). The value should be smaller or equal to the dimensions of the image. Make patch size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbyf-RevQhDL","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UQ2hultWQlT9","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"id":"wYdTY6ULg01b","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###See Elasticdeform’s license\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\n","\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\n","\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","print(\"Double click to see elasticdeform’s license\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKLB47jgQrxr","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = True #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pjz-5bRVh1ja","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"zeSUtd2Thw-O","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"LqH54fYhdbXU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RzAHUsi-78Ak","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w3Z7Jkv8bPvq","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"05dbg6UrGunj","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mBkuXf5zhHUd","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"i9ek_kIHhK1R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["from PIL import Image\n","\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 diff --git a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index 786dcaf3..c65e8b5d --- a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb @@ -53,7 +53,7 @@ "\n", "[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n", "\n", - "The following three Python libraries play an important role in the notebook: \n", + "The following two Python libraries play an important role in the notebook: \n", "\n", "1. [**Elasticdeform**](/~https://github.com/gvtulder/elasticdeform)\n", " by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n", @@ -1302,13 +1302,13 @@ "\n", "#@markdown ###Data augmentation\n", "\n", - "apply_data_augmentation = False #@param {type:\"boolean\"}\n", + "apply_data_augmentation = True #@param {type:\"boolean\"}\n", "\n", "# List of augmentations\n", "augmentations = []\n", "\n", "#@markdown ###Gaussian blur\n", - "add_gaussian_blur = False #@param {type:\"boolean\"}\n", + "add_gaussian_blur = True #@param {type:\"boolean\"}\n", "gaussian_sigma = 0.7#@param {type:\"number\"}\n", "gaussian_frequency = 0.5 #@param {type:\"number\"}\n", "\n", @@ -1316,7 +1316,7 @@ " augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n", "\n", "#@markdown ###Linear contrast\n", - "add_linear_contrast = False #@param {type:\"boolean\"}\n", + "add_linear_contrast = True #@param {type:\"boolean\"}\n", "contrast_min = 0.4 #@param {type:\"number\"}\n", "contrast_max = 1.6#@param {type:\"number\"}\n", "contrast_frequency = 0.5 #@param {type:\"number\"}\n", @@ -1335,24 +1335,22 @@ "\n", "#@markdown ###Add custom augmenters\n", "\n", - "augmenters = \"\" #@param {type:\"string\"}\n", + "augmenters = \"GammaContrast; AverageBlur; LinearContrast\" #@param {type:\"string\"}\n", "\n", - "augmenter_params = \"\" #@param {type:\"string\"}\n", + "augmenter_params = \"(0.5, 2.0); (0.5, 2.0); (0.4, 1.6)\" #@param {type:\"string\"}\n", "\n", - "augmenter_frequency = \"\" #@param {type:\"string\"}\n", + "augmenter_frequency = \"0.3; 0.4; 0.5\" #@param {type:\"string\"}\n", "\n", - "if len(augmenters) > 0 and len(augmenter_params) > 0 and len(augmenter_frequency) > 0:\n", - " aug_lst = augmenters.split(';')\n", - " aug_params_lst = augmenter_params.split(';')\n", - " aug_freq_lst = augmenter_frequency.split(';')\n", + "aug_lst = augmenters.split(';')\n", + "aug_params_lst = augmenter_params.split(';')\n", + "aug_freq_lst = augmenter_frequency.split(';')\n", "\n", - " assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n", + "assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n", "\n", - " for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n", - " aug, param, freq = aug.strip(), param.strip(), freq.strip() \n", - " aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n", - " if apply_data_augmentation:\n", - " augmentations.append(aug_func)\n", + "for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n", + " aug, param, freq = aug.strip(), param.strip(), freq.strip() \n", + " aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n", + " augmentations.append(aug_func)\n", "\n", "#@markdown ###Elastic deformations\n", "add_elastic_deform = True #@param {type:\"boolean\"}\n", @@ -1388,10 +1386,7 @@ " downscale=downscaling_in_xy,\n", " binary_target=binary_target)\n", "\n", - "if apply_data_augmentation:\n", - " sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n", - "else:\n", - " sample_src_aug, sample_tgt_aug = train_generator.__getitem__(random.randint(0, len(train_generator)))\n", + "sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n", "\n", "def scroll_in_z(z):\n", " f=plt.figure(figsize=(16,8))\n", @@ -1499,8 +1494,8 @@ "metadata": { "id": "y_DtHgr-41K0", "colab_type": "code", - "cellView": "form", - "colab": {} + "colab": {}, + "cellView": "form" }, "source": [ "#@markdown ##Download model directory\n", @@ -1510,7 +1505,7 @@ "\n", "from google.colab import files\n", "\n", - "model_path_download = \"\" #@param {type:\"string\"}\n", + "model_path_download = \"/content/gdrive/My Drive/Crick/ZeroCostDL4Mic/Eva Masters Project/test_newest_notebook_run1\" #@param {type:\"string\"}\n", "\n", "if len(model_path_download) == 0:\n", " model_path_download = full_model_path\n", @@ -1550,7 +1545,11 @@ "cellView": "form", "colab_type": "code", "id": "EdcnkCr9Nbl8", - "colab": {} + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "1088857c-2d43-4018-95dd-77f4ca24a576" }, "source": [ "#@markdown ###Model to be evaluated:\n", @@ -1574,7 +1573,15 @@ " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" ], "execution_count": null, - "outputs": [] + "outputs": [ + { + "output_type": "stream", + "text": [ + "test_newest_notebook_run1 will be evaluated\n" + ], + "name": "stdout" + } + ] }, { "cell_type": "markdown", @@ -1939,7 +1946,7 @@ "\n", "# Tifffile library issues means that images cannot be appended to \n", "#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n", - "big_tiff = False #@param {type:\"boolean\"}\n", + "big_tiff = True #@param {type:\"boolean\"}\n", "\n", "#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. Only relevant if prediction saved as BigTIFF\n", "\n", diff --git a/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index e98d0c25..f7d4d452 --- a/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1VcTsLOL28ntbr23gYrhY3upxkztZeUvn","timestamp":1591024690909},{"file_id":"19jT_GoHGN-UTM1aEgkgrOjB8pcFz5AW4","timestamp":1591017297795},{"file_id":"1UkoWB27ZWh5j_qivSZIOeOJP1h2EqrVz","timestamp":1589363183397},{"file_id":"1ofNqOc7lz-m6NL4B-m4BIheaU5N0GMln","timestamp":1588873191434},{"file_id":"1rJnsgIKyL6vuneydIfjCKMtMhV3XlQ6o","timestamp":1588583580765},{"file_id":"1RUYrp8beEgDKL1kOWw5LgR1QQb4yHQtG","timestamp":1587061416704},{"file_id":"1FVax0eY3-m8DbJHx0B8Dnep-uGlp30Zt","timestamp":1586601038120},{"file_id":"1TTqmCf2mFQ_PNIZEXX9sRAhoixjYP_AB","timestamp":1585842446113},{"file_id":"1cWwS-jbLYTDOpPp_hhKOLGFXfu06ccpG","timestamp":1585821375983},{"file_id":"1TPEE_AtGTLedawgVBwwXofEJEcJUCgo3","timestamp":1585137343783},{"file_id":"1SxFRb38aC_kmKzKVQfkwWzkK9n7YFxVv","timestamp":1585053829456},{"file_id":"15iw9IOwHNF_GhiHxkh_rWbJG8JnW14Wh","timestamp":1584375074441},{"file_id":"15oMbXnMa4LDEMhPHBr3ga0xhJomMLhDo","timestamp":1584105762670},{"file_id":"1__NtYFNA3DxNB7LrUY13Bt8_frye3iWl","timestamp":1583445015203},{"file_id":"11jsQfqKeDU1Zk3nPykjWKwYhFmvJ1zJ-","timestamp":1575289898486}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"WDrFAwpFIpE0","colab_type":"text"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"ABNu2p4stHeB","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"HVwncY_NvlYi","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"JrGNzgEyxzGQ","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"wYoajeT54sQM","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"TpT6gbwURzrV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"quzkzlRD45HF","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"eLwDxBnp4-bc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"leK5kmgD5Ism","colab_type":"text"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vOeLpQfT0QF1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","#We enforce the keras==2.2.5 release to ensure that the notebook continues working even if keras is updated.\n","\n","!pip install keras==2.2.5\n","!pip install data\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n","\n"," all_patches_img = np.empty([0,patch_width, patch_height])\n"," all_patches_mask = np.empty([0,patch_width, patch_height])\n","\n"," for file in os.listdir(Training_source):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n"," #the shape of patches_img and patches_mask will be (number of patches along x, number of patches along y,patch_width,patch_height)\n","\n"," all_patches_img = np.concatenate((all_patches_img, patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)), axis = 0)\n"," all_patches_mask = np.concatenate((all_patches_mask, patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)), axis = 0)\n","\n"," number_of_patches = all_patches_img.shape[0]\n"," print('number of patches: '+str(number_of_patches))\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(number_of_patches):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(i)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(i)+'.tif')\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(all_patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*0.98))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_image.tif', img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_mask.tif', convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# def predictionGenerator(Data_path, target_size = (256,256), as_gray = True):\n","# for filename in os.listdir(Data_path):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = io.imread(os.path.join(Data_path, filename), as_gray = as_gray)\n","# img = normalizePercentile(img)\n","# # img = img/255 # WARNING: this is expecting 8bit images\n","# img = transform.resize(img,target_size, preserve_range=True, anti_aliasing=True, order = 1) # liner interpolation\n","# img = np.reshape(img,img.shape+(1,))\n","# img = np.reshape(img,(1,)+img.shape)\n","# yield img\n","\n","\n","# def predictionResize(Data_path, predictions):\n","# resized_predictions = []\n","# for (i, filename) in enumerate(os.listdir(Data_path)):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = Image.open(os.path.join(Data_path, filename))\n","# (width, height) = img.size\n","# resized_predictions.append(transform.resize(predictions[i], (height, width), preserve_range=True, anti_aliasing=True, order = 1))\n","# return resized_predictions\n","\n","\n","# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7hTKImff6Est","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"S74FbqV6PNNv","colab_type":"text"},"source":["##**3.1. Parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"3np5EpJF8_q2","colab_type":"text"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n","**Make sure the directories exist before entering them!**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. Smaller values can be used for testing. **Default: 6**\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n"]},{"cell_type":"code","metadata":{"id":"7deNuPZd5d-B","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 6#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height)\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V9UCjlLJ5Rfc","colab_type":"text"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"i-PahNX94-pl","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7vFEIHbNAuOs","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"RfR9UyKAAulw","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"94FX4wzE8w1W","colab_type":"text"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"tlTDGcmDDHDe","colab_type":"text"},"source":["## **4.1. Prepare model for training**\n","---"]},{"cell_type":"code","metadata":{"id":"ezFy_mpz_op4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","\n","# n = 0 \n","# while os.path.isdir(os.path.join(Training_source, source_images[n])):\n","# n += 1\n","\n","# (width, height) = Image.open(os.path.join(Training_target, source_images[n])).size\n","# ImageSize = (height, width) # np.shape different from PIL image.size return !\n","\n","# !!! WARNING !!! Check potential issues with resizing at the ImageDataGenerator level\n","# (train_datagen, validation_datagen) = prepareGenerators(Training_source, Training_target, data_gen_args, batch_size, target_size = ImageSize)\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we make sure this is properly defined\n","if not Use_pretrained_model:\n"," h5_file_path = None\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","# Dfine CSV logger that will create the loss file (we're not using this anylonger)\n","# csv_log = CSVLogger(os.path.join(full_model_path, 'Quality Control', 'training_evaluation.csv'), separator=',', append=False)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","\n","if Use_Default_Advanced_Parameters:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"urpQ9UM-6NBE","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","\n","####**Be patient**. Please be patient, this may take a while. But the verbose allow you to estimate how fast it's training and how long it'll take. While it's training, please make sure that the computer is not powering down due to inactivity, otherwise this will interupt the runtime."]},{"cell_type":"code","metadata":{"id":"sMyCENd29TKz","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LWaFk0JNda-N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"mEMcFNHZdmTz","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"X11zGW0Ldu-z","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pkJyRzWJCrKG","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qul6BpaX1GqS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"h33P0C2geqZu","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"Tpqjvwv2zug-","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. Threshold')\n","plt.ylabel('Threshold value')\n","plt.xlabel('IoU')\n","plt.legend()\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n","\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gofmRsLP96O8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Pv_v1Ru2OJkU","colab_type":"text"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"FJAe55ZoOJGs","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"su-Mo2POVpja","colab_type":"text"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"iC_B_9lxNUny","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wYmwCQKjYsJ7","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"sCXzzvnh2_rc","colab_type":"text"},"source":["#**Thank you for using U-Net!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1VcTsLOL28ntbr23gYrhY3upxkztZeUvn","timestamp":1591024690909},{"file_id":"19jT_GoHGN-UTM1aEgkgrOjB8pcFz5AW4","timestamp":1591017297795},{"file_id":"1UkoWB27ZWh5j_qivSZIOeOJP1h2EqrVz","timestamp":1589363183397},{"file_id":"1ofNqOc7lz-m6NL4B-m4BIheaU5N0GMln","timestamp":1588873191434},{"file_id":"1rJnsgIKyL6vuneydIfjCKMtMhV3XlQ6o","timestamp":1588583580765},{"file_id":"1RUYrp8beEgDKL1kOWw5LgR1QQb4yHQtG","timestamp":1587061416704},{"file_id":"1FVax0eY3-m8DbJHx0B8Dnep-uGlp30Zt","timestamp":1586601038120},{"file_id":"1TTqmCf2mFQ_PNIZEXX9sRAhoixjYP_AB","timestamp":1585842446113},{"file_id":"1cWwS-jbLYTDOpPp_hhKOLGFXfu06ccpG","timestamp":1585821375983},{"file_id":"1TPEE_AtGTLedawgVBwwXofEJEcJUCgo3","timestamp":1585137343783},{"file_id":"1SxFRb38aC_kmKzKVQfkwWzkK9n7YFxVv","timestamp":1585053829456},{"file_id":"15iw9IOwHNF_GhiHxkh_rWbJG8JnW14Wh","timestamp":1584375074441},{"file_id":"15oMbXnMa4LDEMhPHBr3ga0xhJomMLhDo","timestamp":1584105762670},{"file_id":"1__NtYFNA3DxNB7LrUY13Bt8_frye3iWl","timestamp":1583445015203},{"file_id":"11jsQfqKeDU1Zk3nPykjWKwYhFmvJ1zJ-","timestamp":1575289898486}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"WDrFAwpFIpE0","colab_type":"text"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"ABNu2p4stHeB","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"HVwncY_NvlYi","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"JrGNzgEyxzGQ","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"wYoajeT54sQM","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"TpT6gbwURzrV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"quzkzlRD45HF","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"eLwDxBnp4-bc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"leK5kmgD5Ism","colab_type":"text"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vOeLpQfT0QF1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","# %tensorflow_version 1.x\n","# import tensorflow\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","#We enforce the keras==2.2.5 release to ensure that the notebook continues working even if keras is updated.\n","\n","!pip install keras==2.2.5\n","!pip install data\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n","\n"," all_patches_img = np.empty([0,patch_width, patch_height])\n"," all_patches_mask = np.empty([0,patch_width, patch_height])\n","\n"," for file in os.listdir(Training_source):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n"," #the shape of patches_img and patches_mask will be (number of patches along x, number of patches along y,patch_width,patch_height)\n","\n"," all_patches_img = np.concatenate((all_patches_img, patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)), axis = 0)\n"," all_patches_mask = np.concatenate((all_patches_mask, patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)), axis = 0)\n","\n"," number_of_patches = all_patches_img.shape[0]\n"," print('number of patches: '+str(number_of_patches))\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(number_of_patches):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(i)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(i)+'.tif')\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(all_patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*0.98))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_image.tif', img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_mask.tif', convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# def predictionGenerator(Data_path, target_size = (256,256), as_gray = True):\n","# for filename in os.listdir(Data_path):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = io.imread(os.path.join(Data_path, filename), as_gray = as_gray)\n","# img = normalizePercentile(img)\n","# # img = img/255 # WARNING: this is expecting 8bit images\n","# img = transform.resize(img,target_size, preserve_range=True, anti_aliasing=True, order = 1) # liner interpolation\n","# img = np.reshape(img,img.shape+(1,))\n","# img = np.reshape(img,(1,)+img.shape)\n","# yield img\n","\n","\n","# def predictionResize(Data_path, predictions):\n","# resized_predictions = []\n","# for (i, filename) in enumerate(os.listdir(Data_path)):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = Image.open(os.path.join(Data_path, filename))\n","# (width, height) = img.size\n","# resized_predictions.append(transform.resize(predictions[i], (height, width), preserve_range=True, anti_aliasing=True, order = 1))\n","# return resized_predictions\n","\n","\n","# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7hTKImff6Est","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"S74FbqV6PNNv","colab_type":"text"},"source":["##**3.1. Parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"3np5EpJF8_q2","colab_type":"text"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n","**Make sure the directories exist before entering them!**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. Smaller values can be used for testing. **Default: 6**\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n"]},{"cell_type":"code","metadata":{"id":"7deNuPZd5d-B","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 6#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height)\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V9UCjlLJ5Rfc","colab_type":"text"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"i-PahNX94-pl","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7vFEIHbNAuOs","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"RfR9UyKAAulw","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"94FX4wzE8w1W","colab_type":"text"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"tlTDGcmDDHDe","colab_type":"text"},"source":["## **4.1. Prepare model for training**\n","---"]},{"cell_type":"code","metadata":{"id":"ezFy_mpz_op4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","\n","# n = 0 \n","# while os.path.isdir(os.path.join(Training_source, source_images[n])):\n","# n += 1\n","\n","# (width, height) = Image.open(os.path.join(Training_target, source_images[n])).size\n","# ImageSize = (height, width) # np.shape different from PIL image.size return !\n","\n","# !!! WARNING !!! Check potential issues with resizing at the ImageDataGenerator level\n","# (train_datagen, validation_datagen) = prepareGenerators(Training_source, Training_target, data_gen_args, batch_size, target_size = ImageSize)\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we make sure this is properly defined\n","if not Use_pretrained_model:\n"," h5_file_path = None\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","# Dfine CSV logger that will create the loss file (we're not using this anylonger)\n","# csv_log = CSVLogger(os.path.join(full_model_path, 'Quality Control', 'training_evaluation.csv'), separator=',', append=False)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","\n","if Use_Default_Advanced_Parameters:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"urpQ9UM-6NBE","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","\n","####**Be patient**. Please be patient, this may take a while. But the verbose allow you to estimate how fast it's training and how long it'll take. While it's training, please make sure that the computer is not powering down due to inactivity, otherwise this will interupt the runtime."]},{"cell_type":"code","metadata":{"id":"sMyCENd29TKz","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LWaFk0JNda-N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"mEMcFNHZdmTz","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"X11zGW0Ldu-z","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pkJyRzWJCrKG","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qul6BpaX1GqS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"h33P0C2geqZu","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"Tpqjvwv2zug-","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. Threshold')\n","plt.ylabel('Threshold value')\n","plt.xlabel('IoU')\n","plt.legend()\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n","\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gofmRsLP96O8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Pv_v1Ru2OJkU","colab_type":"text"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"FJAe55ZoOJGs","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"su-Mo2POVpja","colab_type":"text"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"iC_B_9lxNUny","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wYmwCQKjYsJ7","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"sCXzzvnh2_rc","colab_type":"text"},"source":["#**Thank you for using U-Net!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 diff --git a/Colab_notebooks/ZeroCostDL4Mic_UserManual_v1.2.pdf b/Colab_notebooks/ZeroCostDL4Mic_UserManual_v1.2.pdf old mode 100755 new mode 100644 diff --git a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644 index 7995fce1..b8d685f1 --- a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1G6lQzjd259Yoy_OozBhJolF4HraE52PG","timestamp":1591353884724},{"file_id":"1pSC680miQesRinU8Tjn7X6AmJXNtxUNI","timestamp":1591182507229},{"file_id":"1ajYZgvhQfpcUZ5YWlUeB-GUU_j-njsqw","timestamp":1589209398121},{"file_id":"1QiFrHg_cVlOl_yzu-RO9mMIrA2L1dXwj","timestamp":1587744376035},{"file_id":"1_S3UtNcuAaZhVc4yqlFDHc2eKq1x-ynn","timestamp":1587058075616},{"file_id":"1Gce_llcAX7yJTFZP2HiNpTL56gXR7PQ-","timestamp":1586854238074},{"file_id":"10l0NA5VWlqRvDlJRTxOiOUgN5LxEo2gy","timestamp":1586601464429},{"file_id":"1NSdad2BEDJZ16AO3SEEaG-ZSe0o4u3eY","timestamp":1586368373257},{"file_id":"1ubiSLYW3G4eNGNF31e2Vbw_3jMHJ9Y7M","timestamp":1585303720184},{"file_id":"1O6YzESEk9VFr6Nc6ijOAYCtiP80uuh7I","timestamp":1585248652537},{"file_id":"1DPrSIbf-ML-LIO2e4YhL1KedWVsVcFlT","timestamp":1585232236512},{"file_id":"1Qanbeybd44tHmdzKxTJAMDD4trFdCYwD","timestamp":1585049767771},{"file_id":"1Fr9Ea5QdUgK0CKfQKpq9KrxtxxAkSVwc","timestamp":1584619265981},{"file_id":"1RQ6XuOBIRaWgId2WKO2i-MMnXoKn_tNA","timestamp":1584541702239},{"file_id":"1mAvQKCCelwK8zPkAWFvKtiAsE_35KSpW","timestamp":1584533728194},{"file_id":"1LdMzIh-v-gUXnd6v9U2Ov28T-XpeT1PP","timestamp":1584463518766},{"file_id":"18Y0NabtThelB0uOAJlg7UbjHPYMEoCqW","timestamp":1584455459923},{"file_id":"1ZCnLW6HUl0bXrPa-54-bv_C9f6jYL0T4","timestamp":1584436296801},{"file_id":"1gTLXTd_rOpXmlktZz2yeEW62gY8ety-I","timestamp":1583941948440},{"file_id":"1gC_pmaDD73tD-yNoFGjHEolYfLd_7czL","timestamp":1583593255888},{"file_id":"17pZee2Vp0kCh3W8pfzRYk8asqk35mOfw","timestamp":1583335080677},{"file_id":"1KyYm3JglQpPYnf-aBLLiP-sFgi_A0Og1","timestamp":1583291424450},{"file_id":"1ZJCI2p66noTaLCnVUQJkTR16ig6GAqAx","timestamp":1576151149296}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"C-wdtVN5KUFi","colab_type":"text"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"Qt5Yt1vsD163","colab_type":"text"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"zwILBhMkzKp_","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n","\n"," This notebook provides two opportunities: firstly, to download and train Fnet with data published in the original manuscript or secondly, to upload a personal dataset and train Fnet on it.\n"," The notebook may require a large amount of disk space. If using the datasets from the paper, the available disk space on the user's google drive should contain at least 40GB."]},{"cell_type":"markdown","metadata":{"id":"pcNfrIVpNZC-","colab_type":"text"},"source":["---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"I0aF5U_Y0IFW","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"EBHobPtQ8wx7","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"UphYcwdDS8yO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kRVmtCZB9OQ2","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"QTEFQc6j9RTv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yk96o-_u-27d","colab_type":"text"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"BbYpGlfskzrO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to download fnet to your drive. If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","import datetime\n","import time\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqCe6m-C_PrH","colab_type":"text"},"source":["#**3. Select your paths and parameters**\n","---"]},{"cell_type":"markdown","metadata":{"id":"w5NmDpJ4xvWE","colab_type":"text"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**"]},{"cell_type":"code","metadata":{"id":"PWxNzzgKu9Kb","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,len(source)//10)\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCKcSJxkxi33","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"msrTTcPI1Cav","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 20-30 images)."]},{"cell_type":"code","metadata":{"id":"u_YFN6Bd594L","colab_type":"code","cellView":"form","colab":{}},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," \n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," if len(source)>130:\n"," number_of_images = 130\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"heuBzM5JADYf","colab_type":"text"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"eLllOs_rA62U","colab_type":"text"},"source":["##**4.1. Start Trainning**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3)."]},{"cell_type":"code","metadata":{"id":"xe3TLu7M-3Dk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","#Overwriting old models and saving them separately if True\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fpXr4JlCd5uV","colab_type":"text"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"id":"x41OhmO-hsX3","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QefQX9WUBz0G","colab_type":"text"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"2-0m_-tF9oo-","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before,\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","##CHECK THIS Prediction_model_name\n","with open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_folder+'_val.csv', 'r') as f:\n","#with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 50000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vH3EzxbfD6Uk","colab_type":"code","cellView":"form","colab":{}},"source":["import datetime\n","import time\n","start = time.time()\n","\n","#@markdown ##4.2. Start re-training model\n","!pip install tifffile==2019.7.26\n","import os\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","min, sec = divmod(dt, 60) \n","hour, min = divmod(min, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jwORXPtcqRHZ","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"rVBx2b2MpoFf","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"/content/gdrive/My Drive/NewFnet_2\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNR6bAk6oZJD","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ratRdSDlcQ9G","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YkhOGv3Hp2xI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"vqSH6EQb4BwU","colab_type":"code","cellView":"form","colab":{}},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-Transmitted_light_stacks_Split_data\" #@param{type:\"string\"}\n","Target_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-TOM20_fluorescence_stacks_Split_data\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model == True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V2ghLobACMy6","colab_type":"text"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"SMw0nWXeeC1N","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"8yoXStc8Lo27","colab_type":"code","cellView":"form","colab":{}},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e2f-coEkCf58","colab_type":"text"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"Uzv5rp6LrYQF","colab_type":"code","cellView":"form","colab":{}},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Select the slice would you like to view?\n","slice_number = 1#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dP2CrCVee1m","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"IXXOocFl3on8","colab_type":"text"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"id":"emO85anSThPJ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l52zLRCn3z9v","colab_type":"text"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1-rjE9xp9Jrjkti3DT_bEDEhxx3kgvhEb","timestamp":1597394632737},{"file_id":"1G6lQzjd259Yoy_OozBhJolF4HraE52PG","timestamp":1591353884724},{"file_id":"1pSC680miQesRinU8Tjn7X6AmJXNtxUNI","timestamp":1591182507229},{"file_id":"1ajYZgvhQfpcUZ5YWlUeB-GUU_j-njsqw","timestamp":1589209398121},{"file_id":"1QiFrHg_cVlOl_yzu-RO9mMIrA2L1dXwj","timestamp":1587744376035},{"file_id":"1_S3UtNcuAaZhVc4yqlFDHc2eKq1x-ynn","timestamp":1587058075616},{"file_id":"1Gce_llcAX7yJTFZP2HiNpTL56gXR7PQ-","timestamp":1586854238074},{"file_id":"10l0NA5VWlqRvDlJRTxOiOUgN5LxEo2gy","timestamp":1586601464429},{"file_id":"1NSdad2BEDJZ16AO3SEEaG-ZSe0o4u3eY","timestamp":1586368373257},{"file_id":"1ubiSLYW3G4eNGNF31e2Vbw_3jMHJ9Y7M","timestamp":1585303720184},{"file_id":"1O6YzESEk9VFr6Nc6ijOAYCtiP80uuh7I","timestamp":1585248652537},{"file_id":"1DPrSIbf-ML-LIO2e4YhL1KedWVsVcFlT","timestamp":1585232236512},{"file_id":"1Qanbeybd44tHmdzKxTJAMDD4trFdCYwD","timestamp":1585049767771},{"file_id":"1Fr9Ea5QdUgK0CKfQKpq9KrxtxxAkSVwc","timestamp":1584619265981},{"file_id":"1RQ6XuOBIRaWgId2WKO2i-MMnXoKn_tNA","timestamp":1584541702239},{"file_id":"1mAvQKCCelwK8zPkAWFvKtiAsE_35KSpW","timestamp":1584533728194},{"file_id":"1LdMzIh-v-gUXnd6v9U2Ov28T-XpeT1PP","timestamp":1584463518766},{"file_id":"18Y0NabtThelB0uOAJlg7UbjHPYMEoCqW","timestamp":1584455459923},{"file_id":"1ZCnLW6HUl0bXrPa-54-bv_C9f6jYL0T4","timestamp":1584436296801},{"file_id":"1gTLXTd_rOpXmlktZz2yeEW62gY8ety-I","timestamp":1583941948440},{"file_id":"1gC_pmaDD73tD-yNoFGjHEolYfLd_7czL","timestamp":1583593255888},{"file_id":"17pZee2Vp0kCh3W8pfzRYk8asqk35mOfw","timestamp":1583335080677},{"file_id":"1KyYm3JglQpPYnf-aBLLiP-sFgi_A0Og1","timestamp":1583291424450},{"file_id":"1ZJCI2p66noTaLCnVUQJkTR16ig6GAqAx","timestamp":1576151149296}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"C-wdtVN5KUFi","colab_type":"text"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"Qt5Yt1vsD163","colab_type":"text"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"zwILBhMkzKp_","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n","\n"," This notebook provides two opportunities: firstly, to download and train Fnet with data published in the original manuscript or secondly, to upload a personal dataset and train Fnet on it.\n"," The notebook may require a large amount of disk space. If using the datasets from the paper, the available disk space on the user's google drive should contain at least 40GB."]},{"cell_type":"markdown","metadata":{"id":"pcNfrIVpNZC-","colab_type":"text"},"source":["---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"I0aF5U_Y0IFW","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"EBHobPtQ8wx7","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"UphYcwdDS8yO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kRVmtCZB9OQ2","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"QTEFQc6j9RTv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yk96o-_u-27d","colab_type":"text"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"BbYpGlfskzrO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to download fnet to your drive. If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","import datetime\n","import time\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","#!pip install -U scipy==1.2.0\n","#!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqCe6m-C_PrH","colab_type":"text"},"source":["#**3. Select your paths and parameters**\n","---"]},{"cell_type":"markdown","metadata":{"id":"w5NmDpJ4xvWE","colab_type":"text"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**"]},{"cell_type":"code","metadata":{"id":"PWxNzzgKu9Kb","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,len(source)//10)\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCKcSJxkxi33","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"msrTTcPI1Cav","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 20-30 images)."]},{"cell_type":"code","metadata":{"id":"u_YFN6Bd594L","colab_type":"code","cellView":"form","colab":{}},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," \n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(source)>110:\n"," number_of_images = 110\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"heuBzM5JADYf","colab_type":"text"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"eLllOs_rA62U","colab_type":"text"},"source":["##**4.1. Start Trainning**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3)."]},{"cell_type":"code","metadata":{"id":"xe3TLu7M-3Dk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","#Overwriting old models and saving them separately if True\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fpXr4JlCd5uV","colab_type":"text"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"id":"x41OhmO-hsX3","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QefQX9WUBz0G","colab_type":"text"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"2-0m_-tF9oo-","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before,\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","##CHECK THIS Prediction_model_name\n","with open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_folder+'_val.csv', 'r') as f:\n","#with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 50000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vH3EzxbfD6Uk","colab_type":"code","cellView":"form","colab":{}},"source":["import datetime\n","import time\n","start = time.time()\n","\n","#@markdown ##4.2. Start re-training model\n","!pip install tifffile==2019.7.26\n","import os\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","min, sec = divmod(dt, 60) \n","hour, min = divmod(min, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jwORXPtcqRHZ","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"rVBx2b2MpoFf","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNR6bAk6oZJD","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ratRdSDlcQ9G","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YkhOGv3Hp2xI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"vqSH6EQb4BwU","colab_type":"code","cellView":"form","colab":{}},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model == True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V2ghLobACMy6","colab_type":"text"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"SMw0nWXeeC1N","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"8yoXStc8Lo27","colab_type":"code","cellView":"form","colab":{}},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e2f-coEkCf58","colab_type":"text"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"Uzv5rp6LrYQF","colab_type":"code","cellView":"form","colab":{}},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Select the slice would you like to view?\n","slice_number = 1#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dP2CrCVee1m","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"IXXOocFl3on8","colab_type":"text"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"id":"emO85anSThPJ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l52zLRCn3z9v","colab_type":"text"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb old mode 100755 new mode 100644